Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala   | 19
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala | 19
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala          | 13

3 files changed, 44 insertions(+), 7 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index 336f2fc114..ae98e24a75 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -339,10 +339,15 @@ private object BisectingKMeans extends Serializable {
     assignments.map { case (index, v) =>
       if (divisibleIndices.contains(index)) {
         val children = Seq(leftChildIndex(index), rightChildIndex(index))
-        val selected = children.minBy { child =>
-          KMeans.fastSquaredDistance(newClusterCenters(child), v)
+        val newClusterChildren = children.filter(newClusterCenters.contains(_))
+        if (newClusterChildren.nonEmpty) {
+          val selected = newClusterChildren.minBy { child =>
+            KMeans.fastSquaredDistance(newClusterCenters(child), v)
+          }
+          (selected, v)
+        } else {
+          (index, v)
         }
-        (selected, v)
       } else {
         (index, v)
       }
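
Why the guard matters: when a split leaves one child cluster empty, newClusterCenters has no entry for that child's index, so the old unguarded minBy lookup threw java.util.NoSuchElementException. Below is a minimal standalone sketch of the failure mode and the guard, using hypothetical one-dimensional stand-in centers rather than the patch's VectorWithNorm type:

// Sketch only: 1-D doubles stand in for cluster centers (hypothetical, not Spark's types).
object EmptyChildSketch {
  def main(args: Array[String]): Unit = {
    // Cluster 1 was split; only the left child (index 2) received points,
    // so the right child (index 3) has no center.
    val newClusterCenters = Map(2L -> 0.1)
    val children = Seq(2L, 3L)
    val point = 0.4

    // Pre-patch: newClusterCenters(3L) throws NoSuchElementException ("key not found: 3").
    // val selected = children.minBy(c => math.abs(newClusterCenters(c) - point))

    // Post-patch: consider only children that actually have centers,
    // falling back to the parent index when both are empty.
    val valid = children.filter(newClusterCenters.contains(_))
    val selected = if (valid.nonEmpty) {
      valid.minBy(c => math.abs(newClusterCenters(c) - point))
    } else 1L
    println(s"point $point assigned to cluster $selected")
  }
}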
@@ -372,12 +377,12 @@ private object BisectingKMeans extends Serializable {
         internalIndex -= 1
         val leftIndex = leftChildIndex(rawIndex)
         val rightIndex = rightChildIndex(rawIndex)
-        val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex =>
+        val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
+        val height = math.sqrt(indexes.map { childIndex =>
           KMeans.fastSquaredDistance(center, clusters(childIndex).center)
         }.max)
-        val left = buildSubTree(leftIndex)
-        val right = buildSubTree(rightIndex)
-        new ClusteringTreeNode(index, size, center, cost, height, Array(left, right))
+        val children = indexes.map(buildSubTree(_)).toArray
+        new ClusteringTreeNode(index, size, center, cost, height, children)
       } else {
         val index = leafIndex
         leafIndex += 1
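
The same guard applies when the cluster tree is assembled: if one child index was never materialized, calling buildSubTree on it would fail the clusters lookup, so the patched code filters to existing children and lets an internal node carry a single child. A hedged sketch with a simplified stand-in type (not the patch's ClusteringTreeNode):

// Sketch only: a simplified node type stands in for ClusteringTreeNode.
object SingleChildSketch {
  case class Node(index: Long, children: Array[Node])

  def main(args: Array[String]): Unit = {
    val clusters = Map(1L -> "parent", 2L -> "left") // right child (3L) never materialized
    val indexes = Seq(2L, 3L).filter(clusters.contains(_)) // keeps only 2L
    val children = indexes.map(i => Node(i, Array.empty)).toArray
    val root = Node(1L, children) // internal node with one child instead of a crash
    println(s"root has ${root.children.length} child(ren)")
  }
}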
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index fc491cd616..30513c1e27 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -29,9 +29,12 @@ class BisectingKMeansSuite
   final val k = 5
   @transient var dataset: Dataset[_] = _
 
+  @transient var sparseDataset: Dataset[_] = _
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+    sparseDataset = KMeansSuite.generateSparseData(spark, 10, 1000, 42)
   }
 
   test("default parameters") {
@@ -51,6 +54,22 @@ class BisectingKMeansSuite
     assert(copiedModel.hasSummary)
   }
 
+  test("SPARK-16473: Verify Bisecting K-Means does not fail in edge case where " +
+    "one cluster is empty after split") {
+    val bkm = new BisectingKMeans()
+      .setK(k)
+      .setMinDivisibleClusterSize(4)
+      .setMaxIter(4)
+      .setSeed(123)
+
+    // Verify fit does not fail on very sparse data
+    val model = bkm.fit(sparseDataset)
+    val result = model.transform(sparseDataset)
+    val numClusters = result.select("prediction").distinct().collect().length
+    // Verify we hit the edge case
+    assert(numClusters < k && numClusters > 1)
+  }
+
   test("setter/getter") {
     val bkm = new BisectingKMeans()
       .setK(9)
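
For reference, the scenario the new test automates can also be provoked directly: very sparse, high-dimensional points make an empty child after a split likely, and before this patch the fit call below could fail with NoSuchElementException. A hedged, self-contained sketch (object name and data are illustrative, not from the patch):

import org.apache.spark.ml.clustering.BisectingKMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// Sketch only: a small driver reproducing the setup the new test automates.
object SparseBkmRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("bkm-repro").getOrCreate()
    import spark.implicits._

    // 10 very sparse 1000-dimensional points; empty children after a split are likely.
    val data = Seq.tabulate(10)(i => Tuple1(Vectors.sparse(1000, Array(i), Array(1.0))))
      .toDF("features")

    val model = new BisectingKMeans()
      .setK(5)
      .setMinDivisibleClusterSize(4)
      .setMaxIter(4)
      .setSeed(123)
      .fit(data) // pre-patch this could throw NoSuchElementException

    model.transform(data).select("prediction").distinct().show()
    spark.stop()
  }
}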
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c1b7242e11..e10127f7d1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.clustering
 
+import scala.util.Random
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
@@ -160,6 +162,17 @@ object KMeansSuite {
     spark.createDataFrame(rdd)
   }
 
+  def generateSparseData(spark: SparkSession, rows: Int, dim: Int, seed: Int): DataFrame = {
+    val sc = spark.sparkContext
+    val random = new Random(seed)
+    val nnz = random.nextInt(dim)
+    val rdd = sc.parallelize(1 to rows)
+      .map(i => Vectors.sparse(dim, random.shuffle(0 to dim - 1).slice(0, nnz).sorted.toArray,
+        Array.fill(nnz)(random.nextDouble())))
+      .map(v => new TestRow(v))
+    spark.createDataFrame(rdd)
+  }
+
   /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.
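
One behavior of this helper worth noting: nnz is drawn once on the driver before the RDD is built, so every generated row has the same number of nonzero entries; only the positions and values vary from row to row. A hedged usage sketch (assumes a live SparkSession named spark is in scope, as in these suites):

// Sketch only: exercising the new helper from a test with a SparkSession in scope.
val df = KMeansSuite.generateSparseData(spark, 10, 1000, 42)
df.show(3, truncate = false) // 10 rows of 1000-dim sparse vectors, same nnz per row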