aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorShivaram Venkataraman <shivaram@eecs.berkeley.edu>2013-08-06 16:35:47 -0700
committerShivaram Venkataraman <shivaram@eecs.berkeley.edu>2013-08-06 16:35:47 -0700
commit6caec3f44193a459a2dd10b0393e391979795039 (patch)
tree651da54654302504e8dfe9ab765ee857357d26cd /mllib
parent471fbadd0c8cb8d310e3e1dd0e694e357ff1233e (diff)
downloadspark-6caec3f44193a459a2dd10b0393e391979795039.tar.gz
spark-6caec3f44193a459a2dd10b0393e391979795039.tar.bz2
spark-6caec3f44193a459a2dd10b0393e391979795039.zip
Add a test case for random initialization.
Also work around a bug where the double[][] class cast fails
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/spark/mllib/clustering/KMeans.scala4
-rw-r--r--mllib/src/test/scala/spark/mllib/clustering/JavaKMeansSuite.java11
2 files changed, 13 insertions, 2 deletions
diff --git a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
index 750163e1c3..97e3d110ae 100644
--- a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
@@ -194,8 +194,8 @@ class KMeans private (
*/
private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = {
// Sample all the cluster centers in one pass to avoid repeated scans
- val sample = data.takeSample(true, runs * k, new Random().nextInt())
- Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k))
+ val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq
+ Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray)
}
/**
diff --git a/mllib/src/test/scala/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/scala/spark/mllib/clustering/JavaKMeansSuite.java
index f438a92fad..a3db0c0f6d 100644
--- a/mllib/src/test/scala/spark/mllib/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/scala/spark/mllib/clustering/JavaKMeansSuite.java
@@ -86,6 +86,10 @@ public class JavaKMeansSuite implements Serializable {
JavaRDD<double[]> data = sc.parallelize(points, 2);
KMeansModel model = KMeans.train(data.rdd(), 1, 1);
+ assertSetsEqual(model.clusterCenters(), expectedCenter);
+
+ model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.RANDOM());
+ assertSetsEqual(model.clusterCenters(), expectedCenter);
}
@Test
@@ -100,5 +104,12 @@ public class JavaKMeansSuite implements Serializable {
JavaRDD<double[]> data = sc.parallelize(points, 2);
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
assertSetsEqual(model.clusterCenters(), expectedCenter);
+
+ model = new KMeans().setK(1)
+ .setMaxIterations(1)
+ .setRuns(1)
+ .setInitializationMode(KMeans.RANDOM())
+ .run(data.rdd());
+ assertSetsEqual(model.clusterCenters(), expectedCenter);
}
}