From 3f6296fed4ee10f53e728eb1e02f13338839b94d Mon Sep 17 00:00:00 2001 From: FlytxtRnD Date: Tue, 14 Jul 2015 23:29:02 -0700 Subject: [SPARK-8018] [MLLIB] KMeans should accept initial cluster centers as param This allows Kmeans to be initialized using an existing set of cluster centers provided as a KMeansModel object. This mode of initialization performs a single run. Author: FlytxtRnD Closes #6737 from FlytxtRnD/Kmeans-8018 and squashes the following commits: 94b56df [FlytxtRnD] style correction ef95ee2 [FlytxtRnD] style correction c446c58 [FlytxtRnD] documentation and numRuns warning change 06d13ef [FlytxtRnD] numRuns corrected d12336e [FlytxtRnD] numRuns variable modifications 07f8554 [FlytxtRnD] remove setRuns from setIntialModel e721dfe [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018 242ead1 [FlytxtRnD] corrected == to === in assert 714acb5 [FlytxtRnD] added numRuns 60c8ce2 [FlytxtRnD] ignore runs parameter and initialModel test suite changed 582e6d9 [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018 3f5fc8e [FlytxtRnD] test case modified and one runs condition added cd5dc5c [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018 16f1b53 [FlytxtRnD] Merge branch 'Kmeans-8018', remote-tracking branch 'upstream/master' into Kmeans-8018 e9c35d7 [FlytxtRnD] Remove getInitialModel and match cluster count criteria 6959861 [FlytxtRnD] Accept initial cluster centers in KMeans --- docs/mllib-clustering.md | 1 + .../org/apache/spark/mllib/clustering/KMeans.scala | 41 ++++++++++++++++++---- .../spark/mllib/clustering/KMeansSuite.scala | 22 ++++++++++++ 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index d72dc20a5a..0fc7036bff 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -33,6 +33,7 @@ guaranteed to find a globally optimal solution, and when run multiple times on a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. * *epsilon* determines the distance threshold within which we consider k-means to have converged. +* *initialModel* is an optional set of cluster centers used for initialization. If this parameter is supplied, only one run is performed. **Examples** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 0f8d6a3996..68297130a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -156,6 +156,21 @@ class KMeans private ( this } + // Initial cluster centers can be provided as a KMeansModel object rather than using the + // random or k-means|| initializationMode + private var initialModel: Option[KMeansModel] = None + + /** + * Set the initial starting point, bypassing the random initialization or k-means|| + * The condition model.k == this.k must be met, failure results + * in an IllegalArgumentException. + */ + def setInitialModel(model: KMeansModel): this.type = { + require(model.k == k, "mismatched cluster count") + initialModel = Some(model) + this + } + /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. @@ -193,20 +208,34 @@ class KMeans private ( val initStartTime = System.nanoTime() - val centers = if (initializationMode == KMeans.RANDOM) { - initRandom(data) + // Only one run is allowed when initialModel is given + val numRuns = if (initialModel.nonEmpty) { + if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.") + 1 } else { - initKMeansParallel(data) + runs } + val centers = initialModel match { + case Some(kMeansCenters) => { + Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s))) + } + case None => { + if (initializationMode == KMeans.RANDOM) { + initRandom(data) + } else { + initKMeansParallel(data) + } + } + } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) + " seconds.") - val active = Array.fill(runs)(true) - val costs = Array.fill(runs)(0.0) + val active = Array.fill(numRuns)(true) + val costs = Array.fill(numRuns)(0.0) - var activeRuns = new ArrayBuffer[Int] ++ (0 until runs) + var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns) var iteration = 0 val iterationStartTime = System.nanoTime() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 0dbbd71274..3003c62d98 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("Initialize using given cluster centers") { + val points = Seq( + Vectors.dense(0.0, 0.0), + Vectors.dense(1.0, 0.0), + Vectors.dense(0.0, 1.0), + Vectors.dense(1.0, 1.0) + ) + val rdd = sc.parallelize(points, 3) + // creating an initial model + val initialModel = new KMeansModel(Array(points(0), points(2))) + + val returnModel = new KMeans() + .setK(2) + .setMaxIterations(0) + .setInitialModel(initialModel) + .run(rdd) + // comparing the returned model and the initial model + assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0)) + assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1)) + } + } object KMeansSuite extends SparkFunSuite { -- cgit v1.2.3