aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
author Sean Owen <sowen@cloudera.com> 2016-10-30 09:36:23 +0000
committer Sean Owen <sowen@cloudera.com> 2016-10-30 09:36:23 +0000
commit a489567e36e671cee290f8d69188837a8b1a75b3 (patch)
tree fa5b4015686d4ffbf79daefa813ff75f30640c69 /mllib/src/main
parent 505b927cb7ff037adb797b9c3b9ecac3f885b7c8 (diff)
download spark-a489567e36e671cee290f8d69188837a8b1a75b3.tar.gz
spark-a489567e36e671cee290f8d69188837a8b1a75b3.tar.bz2
spark-a489567e36e671cee290f8d69188837a8b1a75b3.zip
[SPARK-3261][MLLIB] KMeans clusterer can return duplicate cluster centers
## What changes were proposed in this pull request?

Return potentially fewer than k cluster centers in cases where k distinct centroids aren't available or aren't selected.

## How was this patch tested?

Existing tests

Author: Sean Owen <sowen@cloudera.com>

Closes #15450 from srowen/SPARK-3261.
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala 27
2 files changed, 20 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 05ed3223ae..85bb8c93b3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -41,7 +41,9 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
with HasSeed with HasPredictionCol with HasTol {
/**
- * The number of clusters to create (k). Must be > 1. Default: 2.
+ * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
+ * k clusters to be returned, for example, if there are fewer than k distinct points to cluster.
+ * Default: 2.
* @group param
*/
@Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 68a7b3b676..ed9c064879 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -56,13 +56,15 @@ class KMeans private (
def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong())
/**
- * Number of clusters to create (k).
+ * Number of clusters to create (k). Note that it is possible for fewer than k clusters to
+ * be returned, for example, if there are fewer than k distinct points to cluster.
*/
@Since("1.4.0")
def getK: Int = k
/**
- * Set the number of clusters to create (k). Default: 2.
+ * Set the number of clusters to create (k). Note that it is possible for fewer than k clusters to
+ * be returned, for example, if there are fewer than k distinct points to cluster. Default: 2.
*/
@Since("0.8.0")
def setK(k: Int): this.type = {
@@ -323,7 +325,10 @@ class KMeans private (
* Initialize a set of cluster centers at random.
*/
private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
- data.takeSample(true, k, new XORShiftRandom(this.seed).nextInt()).map(_.toDense)
+ // Select without replacement; may still produce duplicates if the data has < k distinct
+ // points, so deduplicate the centroids to match the behavior of k-means|| in the same situation
+ data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt())
+ .map(_.vector).distinct.map(new VectorWithNorm(_))
}
/**
@@ -335,7 +340,7 @@ class KMeans private (
*
* The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.
*/
- private def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
+ private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
// Initialize empty centers and point costs.
var costs = data.map(_ => Double.PositiveInfinity)
@@ -378,19 +383,21 @@ class KMeans private (
costs.unpersist(blocking = false)
bcNewCentersList.foreach(_.destroy(false))
- if (centers.size == k) {
- centers.toArray
+ val distinctCenters = centers.map(_.vector).distinct.map(new VectorWithNorm(_))
+
+ if (distinctCenters.size <= k) {
+ distinctCenters.toArray
} else {
- // Finally, we might have a set of more or less than k candidate centers; weight each
+ // Finally, we might have a set of more than k distinct candidate centers; weight each
// candidate by the number of points in the dataset mapping to it and run a local k-means++
// on the weighted centers to pick k of them
- val bcCenters = data.context.broadcast(centers)
+ val bcCenters = data.context.broadcast(distinctCenters)
val countMap = data.map(KMeans.findClosest(bcCenters.value, _)._1).countByValue()
bcCenters.destroy(blocking = false)
- val myWeights = centers.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray
- LocalKMeans.kMeansPlusPlus(0, centers.toArray, myWeights, k, 30)
+ val myWeights = distinctCenters.indices.map(countMap.getOrElse(_, 0L).toDouble).toArray
+ LocalKMeans.kMeansPlusPlus(0, distinctCenters.toArray, myWeights, k, 30)
}
}
}