author     Marek Kolodziej <mkolod@gmail.com>  2013-11-20 07:03:36 -0500
committer  Marek Kolodziej <mkolod@gmail.com>  2013-11-20 07:03:36 -0500
commit     22724659db8d711492f58c90d530be2f4a5b3de9 (patch)
tree       f34b427cb731b8e9810551a67386d4a9ea5c1933 /mllib
parent     bcc6ed30bf7189ebf0226f212b4e39830b830b6e (diff)
Make XORShiftRandom explicit in KMeans and roll it back for RDD
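
For context: this commit drops the aliased import (XORShiftRandom => Random) in favour of the explicit class name, so each call site in KMeans.scala makes clear that a non-standard, xorshift-based generator is in use rather than java.util.Random. The Scala sketch below illustrates the general technique of an xorshift generator exposed as a java.util.Random drop-in; it is only an illustration under that assumption, not Spark's actual org.apache.spark.util.XORShiftRandom source.

import java.util.{Random => JavaRandom}

// Illustrative xorshift generator (not Spark's implementation). Overriding
// next(bits) is enough: nextInt(), nextDouble(), etc. all build on it.
class SimpleXORShiftRandom(init: Long) extends JavaRandom(init) {
  def this() = this(System.nanoTime)
  private var state = init

  override protected def next(bits: Int): Int = {
    var x = state
    x ^= (x << 21)
    x ^= (x >>> 35)
    x ^= (x << 4)
    state = x
    // Hand back the requested number of low-order bits.
    (x & ((1L << bits) - 1)).toInt
  }
}

// Used explicitly at the call site, mirroring the spirit of the change:
//   val seed = new SimpleXORShiftRandom().nextInt()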
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala  8
1 file changed, 4 insertions(+), 4 deletions(-)
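
The notable seeding detail sits in the initKMeansParallel hunk below: each partition derives its generator from seed ^ (step << 16) ^ index, so draws are reproducible for a given base seed yet distinct for every (step, partition) pair. A minimal, hypothetical sketch of that pattern with plain java.util.Random (the names baseSeed and partitionRng are invented for illustration):

import java.util.{Random => JavaRandom}

object PerPartitionSeeding extends App {
  // Base seed chosen once per run, then mixed with the iteration step and the
  // partition index so every (step, partition) pair gets its own stream.
  val baseSeed = new JavaRandom().nextInt()

  def partitionRng(step: Int, partitionIndex: Int): JavaRandom =
    new JavaRandom(baseSeed ^ (step.toLong << 16) ^ partitionIndex)

  // Same arguments always reproduce the same first draw; partitions differ.
  println(partitionRng(step = 1, partitionIndex = 0).nextDouble())
  println(partitionRng(step = 1, partitionIndex = 1).nextDouble())
}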
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index f09ea9e2f7..0dee9399a8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -26,7 +26,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.util.{XORShiftRandom => Random}
+import org.apache.spark.util.XORShiftRandom
@@ -196,7 +196,7 @@ class KMeans private (
*/
private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = {
// Sample all the cluster centers in one pass to avoid repeated scans
- val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq
+ val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray)
}
@@ -211,7 +211,7 @@ class KMeans private (
*/
private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = {
// Initialize each run's center to a random point
- val seed = new Random().nextInt()
+ val seed = new XORShiftRandom().nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r)))
@@ -223,7 +223,7 @@ class KMeans private (
for (r <- 0 until runs) yield (r, KMeans.pointCost(centerArrays(r), point))
}.reduceByKey(_ + _).collectAsMap()
val chosen = data.mapPartitionsWithIndex { (index, points) =>
- val rand = new Random(seed ^ (step << 16) ^ index)
+ val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
for {
p <- points
r <- 0 until runs