aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph.kurata.bradley@gmail.com>2014-07-17 15:05:02 -0700
committerXiangrui Meng <meng@databricks.com>2014-07-17 15:05:02 -0700
commit935fe65ff6559a0e3b481e7508fa14337b23020b (patch)
treeee298094fca9a7aead7c7e3f01abdd952ddc4845 /mllib/src/test/scala
parent1fcd5dcdd8edb0e6989278c95e7f2c7d86c4efb2 (diff)
downloadspark-935fe65ff6559a0e3b481e7508fa14337b23020b.tar.gz
spark-935fe65ff6559a0e3b481e7508fa14337b23020b.tar.bz2
spark-935fe65ff6559a0e3b481e7508fa14337b23020b.zip
SPARK-1215 [MLLIB]: Clustering: Index out of bounds error (2)
Added check to LocalKMeans.scala: kMeansPlusPlus initialization to handle case with fewer distinct data points than clusters k. Added two related unit tests to KMeansSuite. (Re-submitting PR after tangling commits in PR 1407 https://github.com/apache/spark/pull/1407 ) Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Closes #1468 from jkbradley/kmeans-fix and squashes the following commits: 4e9bd1e [Joseph K. Bradley] Updated PR per comments from mengxr 6c7a2ec [Joseph K. Bradley] Added check to LocalKMeans.scala: kMeansPlusPlus initialization to handle case with fewer distinct data points than clusters k. Added two related unit tests to KMeansSuite.
Diffstat (limited to 'mllib/src/test/scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala26
1 files changed, 26 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 560a4ad71a..76a3bdf9b1 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -61,6 +61,32 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
assert(model.clusterCenters.head === center)
}
+ test("no distinct points") {
+ val data = sc.parallelize(
+ Array(
+ Vectors.dense(1.0, 2.0, 3.0),
+ Vectors.dense(1.0, 2.0, 3.0),
+ Vectors.dense(1.0, 2.0, 3.0)),
+ 2)
+ val center = Vectors.dense(1.0, 2.0, 3.0)
+
+ // Make sure code runs.
+ var model = KMeans.train(data, k=2, maxIterations=1)
+ assert(model.clusterCenters.size === 2)
+ }
+
+ test("more clusters than points") {
+ val data = sc.parallelize(
+ Array(
+ Vectors.dense(1.0, 2.0, 3.0),
+ Vectors.dense(1.0, 3.0, 4.0)),
+ 2)
+
+ // Make sure code runs.
+ var model = KMeans.train(data, k=3, maxIterations=1)
+ assert(model.clusterCenters.size === 3)
+ }
+
test("single cluster with big dataset") {
val smallData = Array(
Vectors.dense(1.0, 2.0, 6.0),