Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala   8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala  26
2 files changed, 33 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
index 2e3a4ce783..f0722d7c14 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
@@ -59,7 +59,13 @@ private[mllib] object LocalKMeans extends Logging {
         cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j))
         j += 1
       }
-      centers(i) = points(j-1).toDense
+      if (j == 0) {
+        logWarning("kMeansPlusPlus initialization ran out of distinct points for centers." +
+          s" Using duplicate point for center k = $i.")
+        centers(i) = points(0).toDense
+      } else {
+        centers(i) = points(j - 1).toDense
+      }
     }
 
     // Run up to maxIterations iterations of Lloyd's algorithm
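
The guard above covers the case where j never advances: if every remaining point coincides with an already-chosen center, every pointCost is 0, so sum and r are both 0, the sampling loop body never runs, and the old code indexed points(j - 1) = points(-1). Below is a minimal sketch of that failure mode in plain Scala, without Spark; the object name JZeroSketch and the hard-coded zero costs stand in for KMeans.pointCost and are illustrative only, not part of this patch.

// Sketch only: reproduces the j == 0 condition the patch handles.
object JZeroSketch {
  def main(args: Array[String]): Unit = {
    // All costs are zero, as when every point equals an existing center.
    val costs = Array(0.0, 0.0, 0.0)
    val weights = Array(1.0, 1.0, 1.0)
    val sum = costs.zip(weights).map { case (c, w) => w * c }.sum
    val r = scala.util.Random.nextDouble() * sum // 0.0, since sum == 0.0
    var cumulativeScore = 0.0
    var j = 0
    // Loop condition (cumulativeScore < r) is false immediately: 0.0 < 0.0.
    while (j < costs.length && cumulativeScore < r) {
      cumulativeScore += weights(j) * costs(j)
      j += 1
    }
    println(s"j = $j") // prints j = 0; points(j - 1) would be points(-1)
  }
}
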
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 560a4ad71a..76a3bdf9b1 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -61,6 +61,32 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
     assert(model.clusterCenters.head === center)
   }
 
+  test("no distinct points") {
+    val data = sc.parallelize(
+      Array(
+        Vectors.dense(1.0, 2.0, 3.0),
+        Vectors.dense(1.0, 2.0, 3.0),
+        Vectors.dense(1.0, 2.0, 3.0)),
+      2)
+    val center = Vectors.dense(1.0, 2.0, 3.0)
+
+    // Make sure code runs.
+    var model = KMeans.train(data, k=2, maxIterations=1)
+    assert(model.clusterCenters.size === 2)
+  }
+
+  test("more clusters than points") {
+    val data = sc.parallelize(
+      Array(
+        Vectors.dense(1.0, 2.0, 3.0),
+        Vectors.dense(1.0, 3.0, 4.0)),
+      2)
+
+    // Make sure code runs.
+    var model = KMeans.train(data, k=3, maxIterations=1)
+    assert(model.clusterCenters.size === 3)
+  }
+
   test("single cluster with big dataset") {
     val smallData = Array(
       Vectors.dense(1.0, 2.0, 6.0),