author    Sean Owen <sowen@cloudera.com>  2016-10-30 09:36:23 +0000
committer Sean Owen <sowen@cloudera.com>  2016-10-30 09:36:23 +0000
commit    a489567e36e671cee290f8d69188837a8b1a75b3 (patch)
tree      fa5b4015686d4ffbf79daefa813ff75f30640c69 /mllib
parent    505b927cb7ff037adb797b9c3b9ecac3f885b7c8 (diff)
[SPARK-3261][MLLIB] KMeans clusterer can return duplicate cluster centers
## What changes were proposed in this pull request?

Return potentially fewer than k cluster centers in cases where k distinct centroids aren't available or aren't selected.

## How was this patch tested?

Existing tests

Author: Sean Owen <sowen@cloudera.com>

Closes #15450 from srowen/SPARK-3261.
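For illustration, a minimal sketch of the new behavior (assuming Spark with this patch applied and a SparkContext named `sc` in scope; it mirrors the updated test below rather than adding anything new):

    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors

    // Three identical points: only one distinct centroid is possible.
    val data = sc.parallelize(Seq(
      Vectors.dense(1.0, 2.0, 3.0),
      Vectors.dense(1.0, 2.0, 3.0),
      Vectors.dense(1.0, 2.0, 3.0)))

    // Previously this returned two duplicate centers for k = 2;
    // with this change the model holds a single distinct center.
    val model = KMeans.train(data, k = 2, maxIterations = 1)
    assert(model.clusterCenters.length == 1)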
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala         |   4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala      |  27
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala | 119
3 files changed, 85 insertions, 65 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 05ed3223ae..85bb8c93b3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -41,7 +41,9 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
with HasSeed with HasPredictionCol with HasTol {
/**
- * The number of clusters to create (k). Must be > 1. Default: 2.
+ * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
+ * k clusters to be returned, for example, if there are fewer than k distinct points to cluster.
+ * Default: 2.
* @group param
*/
@Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 68a7b3b676..ed9c064879 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -56,13 +56,15 @@ class KMeans private (
def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong())
/**
- * Number of clusters to create (k).
+ * Number of clusters to create (k). Note that it is possible for fewer than k clusters to
+ * be returned, for example, if there are fewer than k distinct points to cluster.
*/
@Since("1.4.0")
def getK: Int = k
/**
- * Set the number of clusters to create (k). Default: 2.
+ * Set the number of clusters to create (k). Note that it is possible for fewer than k clusters to
+ * be returned, for example, if there are fewer than k distinct points to cluster. Default: 2.
*/
@Since("0.8.0")
def setK(k: Int): this.type = {
@@ -323,7 +325,10 @@ class KMeans private (
* Initialize a set of cluster centers at random.
*/
private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
- data.takeSample(true, k, new XORShiftRandom(this.seed).nextInt()).map(_.toDense)
+ // Select without replacement; may still produce duplicates if the data has < k distinct
+ // points, so deduplicate the centroids to match the behavior of k-means|| in the same situation
+ data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt())
+ .map(_.vector).distinct.map(new VectorWithNorm(_))
}
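An aside on the hunk above: sampling without replacement prevents picking the same element twice, but duplicate values in the data can still be drawn, hence the dedup by vector value. A local, hypothetical illustration with plain Scala collections (not part of the patch):

    import scala.util.Random

    // Two elements share the same value; without-replacement sampling
    // can still select both, so they must be deduplicated by value.
    val data = Seq(Array(1.0, 2.0), Array(1.0, 2.0), Array(3.0, 4.0))
    val sampled = Random.shuffle(data).take(2)
    // Arrays compare by reference, so convert to Seq before distinct.
    val centers = sampled.map(_.toSeq).distinct.map(_.toArray)
    assert(centers.length <= 2)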
/**
@@ -335,7 +340,7 @@ class KMeans private (
*
* The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.
*/
- private def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
+ private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
// Initialize empty centers and point costs.
var costs = data.map(_ => Double.PositiveInfinity)
@@ -378,19 +383,21 @@ class KMeans private (
costs.unpersist(blocking = false)
bcNewCentersList.foreach(_.destroy(false))
- if (centers.size == k) {
- centers.toArray
+ val distinctCenters = centers.map(_.vector).distinct.map(new VectorWithNorm(_))
+
+ if (distinctCenters.size <= k) {
+ distinctCenters.toArray
} else {
- // Finally, we might have a set of more or less than k candidate centers; weight each
+ // Finally, we might have a set of more than k distinct candidate centers; weight each
// candidate by the number of points in the dataset mapping to it and run a local k-means++
// on the weighted centers to pick k of them
- val bcCenters = data.context.broadcast(centers)
+ val bcCenters = data.context.broadcast(distinctCenters)
val countMap = data.map(KMeans.findClosest(bcCenters.value, _)._1).countByValue()
bcCenters.destroy(blocking = false)
- val myWeights = centers.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray
- LocalKMeans.kMeansPlusPlus(0, centers.toArray, myWeights, k, 30)
+ val myWeights = distinctCenters.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray
+ LocalKMeans.kMeansPlusPlus(0, distinctCenters.toArray, myWeights, k, 30)
}
}
}
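To make the weighting step above concrete, a hedged local sketch of the same idea (plain Scala collections standing in for RDDs; the names `candidates`, `points`, and `nearestIndex` are illustrative, not from the patch):

    // Deduplicate candidate centers by value, then weight each by the
    // number of points closest to it, as initKMeansParallel does before
    // handing the weighted candidates to LocalKMeans.kMeansPlusPlus.
    def nearestIndex(centers: Seq[Array[Double]], p: Array[Double]): Int =
      centers.indices.minBy { i =>
        centers(i).zip(p).map { case (a, b) => (a - b) * (a - b) }.sum
      }

    val candidates = Seq(Array(0.0, 0.0), Array(0.0, 0.0), Array(5.0, 5.0))
    val points = Seq(Array(0.1, 0.0), Array(0.0, 0.2), Array(4.9, 5.1))

    // Step 1: keep only distinct candidates (mirrors centers.map(_.vector).distinct).
    val distinctCenters = candidates.map(_.toSeq).distinct.map(_.toArray)

    // Step 2: weight each distinct center by how many points map to it
    // (mirrors countByValue over the closest-center index).
    val counts = points.map(p => nearestIndex(distinctCenters, p))
      .groupBy(identity).mapValues(_.size)
    val weights = distinctCenters.indices.map(i => counts.getOrElse(i, 0).toDouble)
    // weights would then feed the local k-means++ to pick k centers.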
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 2d35b31208..48bd41dc3e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -29,6 +29,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM}
+ private val seed = 42
+
test("single cluster") {
val data = sc.parallelize(Array(
Vectors.dense(1.0, 2.0, 6.0),
@@ -38,7 +40,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
val center = Vectors.dense(1.0, 3.0, 4.0)
- // No matter how many runs or iterations we use, we should get one cluster,
+ // No matter how many iterations we use, we should get one cluster,
// centered at the mean of the points
var model = KMeans.train(data, k = 1, maxIterations = 1)
@@ -50,44 +52,72 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head ~== center absTol 1E-5)
- model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
- assert(model.clusterCenters.head ~== center absTol 1E-5)
-
- model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
- assert(model.clusterCenters.head ~== center absTol 1E-5)
-
- model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
+ model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head ~== center absTol 1E-5)
model = KMeans.train(
- data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL)
+ data, k = 1, maxIterations = 1, initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head ~== center absTol 1E-5)
}
- test("no distinct points") {
+ test("fewer distinct points than clusters") {
val data = sc.parallelize(
Array(
Vectors.dense(1.0, 2.0, 3.0),
Vectors.dense(1.0, 2.0, 3.0),
Vectors.dense(1.0, 2.0, 3.0)),
2)
- val center = Vectors.dense(1.0, 2.0, 3.0)
- // Make sure code runs.
- var model = KMeans.train(data, k = 2, maxIterations = 1)
- assert(model.clusterCenters.size === 2)
- }
+ var model = KMeans.train(data, k = 2, maxIterations = 1, initializationMode = "random")
+ assert(model.clusterCenters.length === 1)
- test("more clusters than points") {
- val data = sc.parallelize(
- Array(
- Vectors.dense(1.0, 2.0, 3.0),
- Vectors.dense(1.0, 3.0, 4.0)),
- 2)
+ model = KMeans.train(data, k = 2, maxIterations = 1, initializationMode = "k-means||")
+ assert(model.clusterCenters.length === 1)
+ }
- // Make sure code runs.
- var model = KMeans.train(data, k = 3, maxIterations = 1)
- assert(model.clusterCenters.size === 3)
+ test("unique cluster centers") {
+ val rng = new Random(seed)
+ val numDistinctPoints = 10
+ val points = (0 until numDistinctPoints).map(i => Vectors.dense(Array.fill(3)(rng.nextDouble)))
+ val data = sc.parallelize(points.flatMap(Array.fill(1 + rng.nextInt(3))(_)), 2)
+ val normedData = data.map(new VectorWithNorm(_))
+
+    // fewer centers than k
+ val km = new KMeans().setK(50)
+ .setMaxIterations(5)
+ .setInitializationMode("k-means||")
+ .setInitializationSteps(10)
+ .setSeed(seed)
+ val initialCenters = km.initKMeansParallel(normedData).map(_.vector)
+ assert(initialCenters.length === initialCenters.distinct.length)
+ assert(initialCenters.length <= numDistinctPoints)
+
+ val model = km.run(data)
+ val finalCenters = model.clusterCenters
+ assert(finalCenters.length === finalCenters.distinct.length)
+
+ // run local k-means
+ val k = 10
+ val km2 = new KMeans().setK(k)
+ .setMaxIterations(5)
+ .setInitializationMode("k-means||")
+ .setInitializationSteps(10)
+ .setSeed(seed)
+ val initialCenters2 = km2.initKMeansParallel(normedData).map(_.vector)
+ assert(initialCenters2.length === initialCenters2.distinct.length)
+ assert(initialCenters2.length === k)
+
+ val model2 = km2.run(data)
+ val finalCenters2 = model2.clusterCenters
+ assert(finalCenters2.length === finalCenters2.distinct.length)
+
+ val km3 = new KMeans().setK(k)
+ .setMaxIterations(5)
+ .setInitializationMode("random")
+ .setSeed(seed)
+ val model3 = km3.run(data)
+ val finalCenters3 = model3.clusterCenters
+ assert(finalCenters3.length === finalCenters3.distinct.length)
}
test("deterministic initialization") {
@@ -97,12 +127,12 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
// Create three deterministic models and compare cluster means
- val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
- initializationMode = initMode, seed = 42)
+ val model1 = KMeans.train(rdd, k = 10, maxIterations = 2,
+ initializationMode = initMode, seed = seed)
val centers1 = model1.clusterCenters
- val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
- initializationMode = initMode, seed = 42)
+ val model2 = KMeans.train(rdd, k = 10, maxIterations = 2,
+ initializationMode = initMode, seed = seed)
val centers2 = model2.clusterCenters
centers1.zip(centers2).foreach { case (c1, c2) =>
@@ -119,7 +149,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
)
val data = sc.parallelize((1 to 100).flatMap(_ => smallData), 4)
- // No matter how many runs or iterations we use, we should get one cluster,
+ // No matter how many iterations we use, we should get one cluster,
// centered at the mean of the points
val center = Vectors.dense(1.0, 3.0, 4.0)
@@ -134,17 +164,10 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head ~== center absTol 1E-5)
- model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head ~== center absTol 1E-5)
- model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
- assert(model.clusterCenters.head ~== center absTol 1E-5)
-
- model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
- assert(model.clusterCenters.head ~== center absTol 1E-5)
-
- model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
- initializationMode = K_MEANS_PARALLEL)
+ model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head ~== center absTol 1E-5)
}
@@ -165,7 +188,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
data.persist()
- // No matter how many runs or iterations we use, we should get one cluster,
+ // No matter how many iterations we use, we should get one cluster,
// centered at the mean of the points
val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0)))
@@ -179,17 +202,10 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head ~== center absTol 1E-5)
- model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
- assert(model.clusterCenters.head ~== center absTol 1E-5)
-
- model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head ~== center absTol 1E-5)
- model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
- assert(model.clusterCenters.head ~== center absTol 1E-5)
-
- model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
- initializationMode = K_MEANS_PARALLEL)
+ model = KMeans.train(data, k = 1, maxIterations = 1, initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head ~== center absTol 1E-5)
data.unpersist()
@@ -230,11 +246,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
model = KMeans.train(rdd, k = 5, maxIterations = 10)
assert(model.clusterCenters.sortBy(VectorWithCompare(_))
.zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5))
-
- // Neither should more runs
- model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5)
- assert(model.clusterCenters.sortBy(VectorWithCompare(_))
- .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5))
}
test("two clusters") {
@@ -250,7 +261,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
// Two iterations are sufficient no matter where the initial centers are.
- val model = KMeans.train(rdd, k = 2, maxIterations = 2, runs = 1, initMode)
+ val model = KMeans.train(rdd, k = 2, maxIterations = 2, initMode)
val predicts = model.predict(rdd).collect()