aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorfreeman <the.freeman.lab@gmail.com>2014-10-31 22:30:12 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-31 22:30:12 -0700
commit98c556ebbca6a815813daaefd292d2e46fb16cc2 (patch)
tree99433d41db58d784b1a5bba9b76c777a70494fa3 /mllib
parent8602195510f5821b37746bb7fa24902f43a1bd93 (diff)
downloadspark-98c556ebbca6a815813daaefd292d2e46fb16cc2.tar.gz
spark-98c556ebbca6a815813daaefd292d2e46fb16cc2.tar.bz2
spark-98c556ebbca6a815813daaefd292d2e46fb16cc2.zip
Streaming KMeans [MLLIB][SPARK-3254]
This adds a Streaming KMeans algorithm to MLlib. It uses an update rule that generalizes the mini-batch KMeans update to incorporate a decay factor, which allows past data to be forgotten. The decay factor can be specified explicitly, or via a more intuitive "fractional decay" setting, in units of either data points or batches. The PR includes: - StreamingKMeans algorithm with decay factor settings - Usage example - Additions to documentation clustering page - Unit tests of basic behavior and decay behaviors tdas mengxr rezazadeh Author: freeman <the.freeman.lab@gmail.com> Author: Jeremy Freeman <the.freeman.lab@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #2942 from freeman-lab/streaming-kmeans and squashes the following commits: b2e5b4a [freeman] Fixes to docs / examples 078617c [Jeremy Freeman] Merge pull request #1 from mengxr/SPARK-3254 2e682c0 [Xiangrui Meng] take discount on previous weights; use BLAS; detect dying clusters 0411bf5 [freeman] Change decay parameterization 9f7aea9 [freeman] Style fixes 374a706 [freeman] Formatting ad9bdc2 [freeman] Use labeled points and predictOnValues in examples 77dbd3f [freeman] Make initialization check an assertion 9cfc301 [freeman] Make random seed an argument 44050a9 [freeman] Simpler constructor c7050d5 [freeman] Fix spacing 2899623 [freeman] Use pattern matching for clarity a4a316b [freeman] Use collect 1472ec5 [freeman] Doc formatting ea22ec8 [freeman] Fix imports 2086bdc [freeman] Log cluster center updates ea9877c [freeman] More documentation 9facbe3 [freeman] Bug fix 5db7074 [freeman] Example usage for StreamingKMeans f33684b [freeman] Add explanation and example to docs b5b5f8d [freeman] Add better documentation a0fd790 [freeman] Merge remote-tracking branch 'upstream/master' into streaming-kmeans 9fd9c15 [freeman] Merge remote-tracking branch 'upstream/master' into streaming-kmeans b93350f [freeman] Streaming KMeans with decay
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala268
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala157
2 files changed, 425 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
new file mode 100644
index 0000000000..6189dce9b2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * :: DeveloperApi ::
+ * StreamingKMeansModel extends MLlib's KMeansModel for streaming
+ * algorithms, so it can keep track of a continuously updated weight
+ * associated with each cluster, and also update the model by
+ * doing a single iteration of the standard k-means algorithm.
+ *
+ * The update algorithm uses the "mini-batch" KMeans rule,
+ * generalized to incorporate forgetfullness (i.e. decay).
+ * The update rule (for each cluster) is:
+ *
+ * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t]
+ * n_t+t = n_t * a + m_t
+ *
+ * Where c_t is the previously estimated centroid for that cluster,
+ * n_t is the number of points assigned to it thus far, x_t is the centroid
+ * estimated on the current batch, and m_t is the number of points assigned
+ * to that centroid in the current batch.
+ *
+ * The decay factor 'a' scales the contribution of the clusters as estimated thus far,
+ * by applying a as a discount weighting on the current point when evaluating
+ * new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids
+ * are determined entirely by recent data. Lower values correspond to
+ * more forgetting.
+ *
+ * Decay can optionally be specified by a half life and associated
+ * time unit. The time unit can either be a batch of data or a single
+ * data point. Considering data arrived at time t, the half life h is defined
+ * such that at time t + h the discount applied to the data from t is 0.5.
+ * The definition remains the same whether the time unit is given
+ * as batches or points.
+ *
+ */
+@DeveloperApi
+class StreamingKMeansModel(
+ override val clusterCenters: Array[Vector],
+ val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging {
+
+ /** Perform a k-means update on a batch of data. */
+ def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
+
+ // find nearest cluster to each point
+ val closest = data.map(point => (this.predict(point), (point, 1L)))
+
+ // get sums and counts for updating each cluster
+ val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
+ BLAS.axpy(1.0, p2._1, p1._1)
+ (p1._1, p1._2 + p2._2)
+ }
+ val dim = clusterCenters(0).size
+ val pointStats: Array[(Int, (Vector, Long))] = closest
+ .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
+ .collect()
+
+ val discount = timeUnit match {
+ case StreamingKMeans.BATCHES => decayFactor
+ case StreamingKMeans.POINTS =>
+ val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
+ n
+ }.sum
+ math.pow(decayFactor, numNewPoints)
+ }
+
+ // apply discount to weights
+ BLAS.scal(discount, Vectors.dense(clusterWeights))
+
+ // implement update rule
+ pointStats.foreach { case (label, (sum, count)) =>
+ val centroid = clusterCenters(label)
+
+ val updatedWeight = clusterWeights(label) + count
+ val lambda = count / math.max(updatedWeight, 1e-16)
+
+ clusterWeights(label) = updatedWeight
+ BLAS.scal(1.0 - lambda, centroid)
+ BLAS.axpy(lambda / count, sum, centroid)
+
+ // display the updated cluster centers
+ val display = clusterCenters(label).size match {
+ case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
+ case _ => centroid.toArray.mkString("[", ",", "]")
+ }
+
+ logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
+ }
+
+ // Check whether the smallest cluster is dying. If so, split the largest cluster.
+ val weightsWithIndex = clusterWeights.view.zipWithIndex
+ val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
+ val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
+ if (minWeight < 1e-8 * maxWeight) {
+ logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
+ val weight = (maxWeight + minWeight) / 2.0
+ clusterWeights(largest) = weight
+ clusterWeights(smallest) = weight
+ val largestClusterCenter = clusterCenters(largest)
+ val smallestClusterCenter = clusterCenters(smallest)
+ var j = 0
+ while (j < dim) {
+ val x = largestClusterCenter(j)
+ val p = 1e-14 * math.max(math.abs(x), 1.0)
+ largestClusterCenter.toBreeze(j) = x + p
+ smallestClusterCenter.toBreeze(j) = x - p
+ j += 1
+ }
+ }
+
+ this
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * StreamingKMeans provides methods for configuring a
+ * streaming k-means analysis, training the model on streaming,
+ * and using the model to make predictions on streaming data.
+ * See KMeansModel for details on algorithm and update rules.
+ *
+ * Use a builder pattern to construct a streaming k-means analysis
+ * in an application, like:
+ *
+ * val model = new StreamingKMeans()
+ * .setDecayFactor(0.5)
+ * .setK(3)
+ * .setRandomCenters(5, 100.0)
+ * .trainOn(DStream)
+ */
+@DeveloperApi
+class StreamingKMeans(
+ var k: Int,
+ var decayFactor: Double,
+ var timeUnit: String) extends Logging {
+
+ def this() = this(2, 1.0, StreamingKMeans.BATCHES)
+
+ protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
+
+ /** Set the number of clusters. */
+ def setK(k: Int): this.type = {
+ this.k = k
+ this
+ }
+
+ /** Set the decay factor directly (for forgetful algorithms). */
+ def setDecayFactor(a: Double): this.type = {
+ this.decayFactor = decayFactor
+ this
+ }
+
+ /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
+ def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
+ if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) {
+ throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
+ }
+ this.decayFactor = math.exp(math.log(0.5) / halfLife)
+ logInfo("Setting decay factor to: %g ".format (this.decayFactor))
+ this.timeUnit = timeUnit
+ this
+ }
+
+ /** Specify initial centers directly. */
+ def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+ model = new StreamingKMeansModel(centers, weights)
+ this
+ }
+
+ /**
+ * Initialize random centers, requiring only the number of dimensions.
+ *
+ * @param dim Number of dimensions
+ * @param weight Weight for each center
+ * @param seed Random seed
+ */
+ def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+ val random = new XORShiftRandom(seed)
+ val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
+ val weights = Array.fill(k)(weight)
+ model = new StreamingKMeansModel(centers, weights)
+ this
+ }
+
+ /** Return the latest model. */
+ def latestModel(): StreamingKMeansModel = {
+ model
+ }
+
+ /**
+ * Update the clustering model by training on batches of data from a DStream.
+ * This operation registers a DStream for training the model,
+ * checks whether the cluster centers have been initialized,
+ * and updates the model using each batch of data from the stream.
+ *
+ * @param data DStream containing vector data
+ */
+ def trainOn(data: DStream[Vector]) {
+ assertInitialized()
+ data.foreachRDD { (rdd, time) =>
+ model = model.update(rdd, decayFactor, timeUnit)
+ }
+ }
+
+ /**
+ * Use the clustering model to make predictions on batches of data from a DStream.
+ *
+ * @param data DStream containing vector data
+ * @return DStream containing predictions
+ */
+ def predictOn(data: DStream[Vector]): DStream[Int] = {
+ assertInitialized()
+ data.map(model.predict)
+ }
+
+ /**
+ * Use the model to make predictions on the values of a DStream and carry over its keys.
+ *
+ * @param data DStream containing (key, feature vector) pairs
+ * @tparam K key type
+ * @return DStream containing the input keys and the predictions as values
+ */
+ def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
+ assertInitialized()
+ data.mapValues(model.predict)
+ }
+
+ /** Check whether cluster centers have been initialized. */
+ private[this] def assertInitialized(): Unit = {
+ if (model.clusterCenters == null) {
+ throw new IllegalStateException(
+ "Initial cluster centers must be set before starting predictions")
+ }
+ }
+}
+
+private[clustering] object StreamingKMeans {
+ final val BATCHES = "batches"
+ final val POINTS = "points"
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
new file mode 100644
index 0000000000..850c9fce50
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.streaming.TestSuiteBase
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.random.XORShiftRandom
+
+class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
+
+ override def maxWaitTimeMillis = 30000
+
+ test("accuracy for single center and equivalence to grand average") {
+ // set parameters
+ val numBatches = 10
+ val numPoints = 50
+ val k = 1
+ val d = 5
+ val r = 0.1
+
+ // create model with one cluster
+ val model = new StreamingKMeans()
+ .setK(1)
+ .setDecayFactor(1.0)
+ .setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0)), Array(0.0))
+
+ // generate random data for k-means
+ val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
+
+ // setup and run the model training
+ val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ model.trainOn(inputDStream)
+ inputDStream.count()
+ })
+ runStreams(ssc, numBatches, numBatches)
+
+ // estimated center should be close to true center
+ assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1)
+
+ // estimated center from streaming should exactly match the arithmetic mean of all data points
+ // because the decay factor is set to 1.0
+ val grandMean =
+ input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble
+ assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5)
+ }
+
+ test("accuracy for two centers") {
+ val numBatches = 10
+ val numPoints = 5
+ val k = 2
+ val d = 5
+ val r = 0.1
+
+ // create model with two clusters
+ val kMeans = new StreamingKMeans()
+ .setK(2)
+ .setHalfLife(2, "batches")
+ .setInitialCenters(
+ Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1),
+ Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1)),
+ Array(5.0, 5.0))
+
+ // generate random data for k-means
+ val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
+
+ // setup and run the model training
+ val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ kMeans.trainOn(inputDStream)
+ inputDStream.count()
+ })
+ runStreams(ssc, numBatches, numBatches)
+
+ // check that estimated centers are close to true centers
+ // NOTE exact assignment depends on the initialization!
+ assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1)
+ assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1)
+ }
+
+ test("detecting dying clusters") {
+ val numBatches = 10
+ val numPoints = 5
+ val k = 1
+ val d = 1
+ val r = 1.0
+
+ // create model with two clusters
+ val kMeans = new StreamingKMeans()
+ .setK(2)
+ .setHalfLife(0.5, "points")
+ .setInitialCenters(
+ Array(Vectors.dense(0.0), Vectors.dense(1000.0)),
+ Array(1.0, 1.0))
+
+ // new data are all around the first cluster 0.0
+ val (input, _) =
+ StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0)))
+
+ // setup and run the model training
+ val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ kMeans.trainOn(inputDStream)
+ inputDStream.count()
+ })
+ runStreams(ssc, numBatches, numBatches)
+
+ // check that estimated centers are close to true centers
+ // NOTE exact assignment depends on the initialization!
+ val model = kMeans.latestModel()
+ val c0 = model.clusterCenters(0)(0)
+ val c1 = model.clusterCenters(1)(0)
+
+ assert(c0 * c1 < 0.0, "should have one positive center and one negative center")
+ // 0.8 is the mean of half-normal distribution
+ assert(math.abs(c0) ~== 0.8 absTol 0.6)
+ assert(math.abs(c1) ~== 0.8 absTol 0.6)
+ }
+
+ def StreamingKMeansDataGenerator(
+ numPoints: Int,
+ numBatches: Int,
+ k: Int,
+ d: Int,
+ r: Double,
+ seed: Int,
+ initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[Vector]], Array[Vector]) = {
+ val rand = new XORShiftRandom(seed)
+ val centers = initCenters match {
+ case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian())))
+ case _ => initCenters
+ }
+ val data = (0 until numBatches).map { i =>
+ (0 until numPoints).map { idx =>
+ val center = centers(idx % k)
+ Vectors.dense(Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r))
+ }
+ }
+ (data, centers)
+ }
+}