aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- docs/mllib-clustering.md 1
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala 41
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala 22
3 files changed, 58 insertions, 6 deletions
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index d72dc20a5a..0fc7036bff 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -33,6 +33,7 @@ guaranteed to find a globally optimal solution, and when run multiple times on
a given dataset, the algorithm returns the best clustering result).
* *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
* *epsilon* determines the distance threshold within which we consider k-means to have converged.
+* *initialModel* is an optional set of cluster centers used for initialization. If this parameter is supplied, only one run is performed.
**Examples**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 0f8d6a3996..68297130a7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -156,6 +156,21 @@ class KMeans private (
this
}
+ // Initial cluster centers can be provided as a KMeansModel object rather than using the
+ // random or k-means|| initializationMode
+ private var initialModel: Option[KMeansModel] = None
+
+ /**
+ * Set the initial cluster centers, bypassing the random and k-means|| initialization.
+ * Requires model.k == this.k; otherwise `require` throws an IllegalArgumentException.
+ * Returns this instance so that setter calls can be chained.
+ */
+ def setInitialModel(model: KMeansModel): this.type = {
+ require(model.k == k, "mismatched cluster count")
+ initialModel = Some(model)
+ this
+ }
+
/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
@@ -193,20 +208,34 @@ class KMeans private (
val initStartTime = System.nanoTime()
- val centers = if (initializationMode == KMeans.RANDOM) {
- initRandom(data)
+ // Only one run is allowed when initialModel is given
+ val numRuns = if (initialModel.nonEmpty) {
+ if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.")
+ 1
} else {
- initKMeansParallel(data)
+ runs
}
+ val centers = initialModel match {
+ case Some(kMeansCenters) => {
+ Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
+ }
+ case None => {
+ if (initializationMode == KMeans.RANDOM) {
+ initRandom(data)
+ } else {
+ initKMeansParallel(data)
+ }
+ }
+ }
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
" seconds.")
- val active = Array.fill(runs)(true)
- val costs = Array.fill(runs)(0.0)
+ val active = Array.fill(numRuns)(true)
+ val costs = Array.fill(numRuns)(0.0)
- var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
+ var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
var iteration = 0
val iterationStartTime = System.nanoTime()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 0dbbd71274..3003c62d98 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
}
+
+ test("Initialize using given cluster centers") {
+ val points = Seq(
+ Vectors.dense(0.0, 0.0),
+ Vectors.dense(1.0, 0.0),
+ Vectors.dense(0.0, 1.0),
+ Vectors.dense(1.0, 1.0)
+ )
+ val rdd = sc.parallelize(points, 3)
+ // Build the initial model from two of the input points, giving k = 2 centers.
+ val initialModel = new KMeansModel(Array(points(0), points(2)))
+
+ val returnModel = new KMeans()
+ .setK(2)
+ .setMaxIterations(0)
+ .setInitialModel(initialModel)
+ .run(rdd)
+ // maxIterations is 0, so the returned centers are expected to equal the initial ones.
+ assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0))
+ assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1))
+ }
+
}
object KMeansSuite extends SparkFunSuite {