aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorIlya Matiach <ilmat@microsoft.com>2017-01-23 13:34:27 -0800
committerJoseph K. Bradley <joseph@databricks.com>2017-01-23 13:34:27 -0800
commit5b258b8b0752d13842d40ae69107f7976678cf17 (patch)
tree9bfccd9d4dfb84b30554a8eddf03b695c7c604d5 /mllib/src
parentc8aea7445c3bc724f4f2d4cae37d59748dd0e678 (diff)
downloadspark-5b258b8b0752d13842d40ae69107f7976678cf17.tar.gz
spark-5b258b8b0752d13842d40ae69107f7976678cf17.tar.bz2
spark-5b258b8b0752d13842d40ae69107f7976678cf17.zip
[SPARK-16473][MLLIB] Fix BisectingKMeans Algorithm failing in edge case
[SPARK-16473][MLLIB] Fix BisectingKMeans Algorithm failing in edge case where no children exist in updateAssignments ## What changes were proposed in this pull request? Fix a bug in which BisectingKMeans fails with error: java.util.NoSuchElementException: key not found: 166 at scala.collection.MapLike$class.default(MapLike.scala:228) at scala.collection.AbstractMap.default(Map.scala:58) at scala.collection.MapLike$class.apply(MapLike.scala:141) at scala.collection.AbstractMap.apply(Map.scala:58) at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1$$anonfun$2.apply$mcDJ$sp(BisectingKMeans.scala:338) at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1$$anonfun$2.apply(BisectingKMeans.scala:337) at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1$$anonfun$2.apply(BisectingKMeans.scala:337) at scala.collection.TraversableOnce$$anonfun$minBy$1.apply(TraversableOnce.scala:231) at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:111) at scala.collection.immutable.List.foldLeft(List.scala:84) at scala.collection.LinearSeqOptimized$class.reduceLeft(LinearSeqOptimized.scala:125) at scala.collection.immutable.List.reduceLeft(List.scala:84) at scala.collection.TraversableOnce$class.minBy(TraversableOnce.scala:231) at scala.collection.AbstractTraversable.minBy(Traversable.scala:105) at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1.apply(BisectingKMeans.scala:337) at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1.apply(BisectingKMeans.scala:334) at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) at scala.collection.Iterator$$anon$14.hasNext(Iterator.scala:389) ## How was this patch tested? The dataset was run against the code change to verify that the code works. I will try to add unit tests to the code. (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ilya Matiach <ilmat@microsoft.com> Closes #16355 from imatiach-msft/ilmat/fix-kmeans.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala19
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala19
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala13
3 files changed, 44 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index 336f2fc114..ae98e24a75 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -339,10 +339,15 @@ private object BisectingKMeans extends Serializable {
assignments.map { case (index, v) =>
if (divisibleIndices.contains(index)) {
val children = Seq(leftChildIndex(index), rightChildIndex(index))
- val selected = children.minBy { child =>
- KMeans.fastSquaredDistance(newClusterCenters(child), v)
+ val newClusterChildren = children.filter(newClusterCenters.contains(_))
+ if (newClusterChildren.nonEmpty) {
+ val selected = newClusterChildren.minBy { child =>
+ KMeans.fastSquaredDistance(newClusterCenters(child), v)
+ }
+ (selected, v)
+ } else {
+ (index, v)
}
- (selected, v)
} else {
(index, v)
}
@@ -372,12 +377,12 @@ private object BisectingKMeans extends Serializable {
internalIndex -= 1
val leftIndex = leftChildIndex(rawIndex)
val rightIndex = rightChildIndex(rawIndex)
- val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex =>
+ val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
+ val height = math.sqrt(indexes.map { childIndex =>
KMeans.fastSquaredDistance(center, clusters(childIndex).center)
}.max)
- val left = buildSubTree(leftIndex)
- val right = buildSubTree(rightIndex)
- new ClusteringTreeNode(index, size, center, cost, height, Array(left, right))
+ val children = indexes.map(buildSubTree(_)).toArray
+ new ClusteringTreeNode(index, size, center, cost, height, children)
} else {
val index = leafIndex
leafIndex += 1
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index fc491cd616..30513c1e27 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -29,9 +29,12 @@ class BisectingKMeansSuite
final val k = 5
@transient var dataset: Dataset[_] = _
+ @transient var sparseDataset: Dataset[_] = _
+
override def beforeAll(): Unit = {
super.beforeAll()
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+ sparseDataset = KMeansSuite.generateSparseData(spark, 10, 1000, 42)
}
test("default parameters") {
@@ -51,6 +54,22 @@ class BisectingKMeansSuite
assert(copiedModel.hasSummary)
}
+ test("SPARK-16473: Verify Bisecting K-Means does not fail in edge case where" +
+ "one cluster is empty after split") {
+ val bkm = new BisectingKMeans()
+ .setK(k)
+ .setMinDivisibleClusterSize(4)
+ .setMaxIter(4)
+ .setSeed(123)
+
+ // Verify fit does not fail on very sparse data
+ val model = bkm.fit(sparseDataset)
+ val result = model.transform(sparseDataset)
+ val numClusters = result.select("prediction").distinct().collect().length
+ // Verify we hit the edge case
+ assert(numClusters < k && numClusters > 1)
+ }
+
test("setter/getter") {
val bkm = new BisectingKMeans()
.setK(9)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c1b7242e11..e10127f7d1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.clustering
+import scala.util.Random
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
@@ -160,6 +162,17 @@ object KMeansSuite {
spark.createDataFrame(rdd)
}
+ def generateSparseData(spark: SparkSession, rows: Int, dim: Int, seed: Int): DataFrame = {
+ val sc = spark.sparkContext
+ val random = new Random(seed)
+ val nnz = random.nextInt(dim)
+ val rdd = sc.parallelize(1 to rows)
+ .map(i => Vectors.sparse(dim, random.shuffle(0 to dim - 1).slice(0, nnz).sorted.toArray,
+ Array.fill(nnz)(random.nextDouble())))
+ .map(v => new TestRow(v))
+ spark.createDataFrame(rdd)
+ }
+
/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.