From 6caec3f44193a459a2dd10b0393e391979795039 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Tue, 6 Aug 2013 16:35:47 -0700 Subject: Add a test case for random initialization. Also workaround a bug where double[][] class cast fails --- mllib/src/main/scala/spark/mllib/clustering/KMeans.scala | 4 ++-- .../test/scala/spark/mllib/clustering/JavaKMeansSuite.java | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) (limited to 'mllib/src') diff --git a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala index 750163e1c3..97e3d110ae 100644 --- a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala @@ -194,8 +194,8 @@ class KMeans private ( */ private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = { // Sample all the cluster centers in one pass to avoid repeated scans - val sample = data.takeSample(true, runs * k, new Random().nextInt()) - Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k)) + val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq + Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray) } /** diff --git a/mllib/src/test/scala/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/scala/spark/mllib/clustering/JavaKMeansSuite.java index f438a92fad..a3db0c0f6d 100644 --- a/mllib/src/test/scala/spark/mllib/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/scala/spark/mllib/clustering/JavaKMeansSuite.java @@ -86,6 +86,10 @@ public class JavaKMeansSuite implements Serializable { JavaRDD data = sc.parallelize(points, 2); KMeansModel model = KMeans.train(data.rdd(), 1, 1); + assertSetsEqual(model.clusterCenters(), expectedCenter); + + model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.RANDOM()); + assertSetsEqual(model.clusterCenters(), expectedCenter); } @Test @@ -100,5 +104,12 @@ public class JavaKMeansSuite implements Serializable { JavaRDD data = sc.parallelize(points, 2); KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); assertSetsEqual(model.clusterCenters(), expectedCenter); + + model = new KMeans().setK(1) + .setMaxIterations(1) + .setRuns(1) + .setInitializationMode(KMeans.RANDOM()) + .run(data.rdd()); + assertSetsEqual(model.clusterCenters(), expectedCenter); } } -- cgit v1.2.3