From 98c556ebbca6a815813daaefd292d2e46fb16cc2 Mon Sep 17 00:00:00 2001 From: freeman Date: Fri, 31 Oct 2014 22:30:12 -0700 Subject: Streaming KMeans [MLLIB][SPARK-3254] This adds a Streaming KMeans algorithm to MLlib. It uses an update rule that generalizes the mini-batch KMeans update to incorporate a decay factor, which allows past data to be forgotten. The decay factor can be specified explicitly, or via a more intuitive "fractional decay" setting, in units of either data points or batches. The PR includes: - StreamingKMeans algorithm with decay factor settings - Usage example - Additions to documentation clustering page - Unit tests of basic behavior and decay behaviors tdas mengxr rezazadeh Author: freeman Author: Jeremy Freeman Author: Xiangrui Meng Closes #2942 from freeman-lab/streaming-kmeans and squashes the following commits: b2e5b4a [freeman] Fixes to docs / examples 078617c [Jeremy Freeman] Merge pull request #1 from mengxr/SPARK-3254 2e682c0 [Xiangrui Meng] take discount on previous weights; use BLAS; detect dying clusters 0411bf5 [freeman] Change decay parameterization 9f7aea9 [freeman] Style fixes 374a706 [freeman] Formatting ad9bdc2 [freeman] Use labeled points and predictOnValues in examples 77dbd3f [freeman] Make initialization check an assertion 9cfc301 [freeman] Make random seed an argument 44050a9 [freeman] Simpler constructor c7050d5 [freeman] Fix spacing 2899623 [freeman] Use pattern matching for clarity a4a316b [freeman] Use collect 1472ec5 [freeman] Doc formatting ea22ec8 [freeman] Fix imports 2086bdc [freeman] Log cluster center updates ea9877c [freeman] More documentation 9facbe3 [freeman] Bug fix 5db7074 [freeman] Example usage for StreamingKMeans f33684b [freeman] Add explanation and example to docs b5b5f8d [freeman] Add better documentation a0fd790 [freeman] Merge remote-tracking branch 'upstream/master' into streaming-kmeans 9fd9c15 [freeman] Merge remote-tracking branch 'upstream/master' into streaming-kmeans b93350f [freeman] Streaming KMeans with decay --- docs/mllib-clustering.md | 96 +++++++- .../spark/examples/mllib/StreamingKMeans.scala | 77 ++++++ .../spark/mllib/clustering/StreamingKMeans.scala | 268 +++++++++++++++++++++ .../mllib/clustering/StreamingKMeansSuite.scala | 157 ++++++++++++ 4 files changed, 597 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 7978e934fb..c696ae9c8e 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -34,7 +34,7 @@ a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. * *epsilon* determines the distance threshold within which we consider k-means to have converged. -## Examples +### Examples
@@ -153,3 +153,97 @@ provided in the [Self-Contained Applications](quick-start.html#self-contained-ap section of the Spark Quick Start guide. Be sure to also include *spark-mllib* to your build file as a dependency. + +## Streaming clustering + +When data arrive in a stream, we may want to estimate clusters dynamically, +updating them as new data arrive. MLlib provides support for streaming k-means clustering, +with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm +uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign +all points to their nearest cluster, compute new cluster centers, then update each cluster using: + +`\begin{equation} + c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t} +\end{equation}` +`\begin{equation} + n_{t+1} = n_t + m_t +\end{equation}` + +Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned +to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$` +is the number of points added to the cluster in the current batch. The decay factor `$\alpha$` +can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning; +with `$\alpha$=0` only the most recent data will be used. This is analogous to an +exponentially-weighted moving average. + +The decay can be specified using a `halfLife` parameter, which determines the +correct decay factor `a` such that, for data acquired +at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5. +The unit of time can be specified either as `batches` or `points` and the update rule +will be adjusted accordingly. + +### Examples + +This example shows how to estimate clusters on streaming data. + +
+ +
+ +First we import the neccessary classes. + +{% highlight scala %} + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.clustering.StreamingKMeans + +{% endhighlight %} + +Then we make an input stream of vectors for training, as well as a stream of labeled data +points for testing. We assume a StreamingContext `ssc` has been created, see +[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. + +{% highlight scala %} + +val trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse) +val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse) + +{% endhighlight %} + +We create a model with random clusters and specify the number of clusters to find + +{% highlight scala %} + +val numDimensions = 3 +val numClusters = 2 +val model = new StreamingKMeans() + .setK(numClusters) + .setDecayFactor(1.0) + .setRandomCenters(numDimensions, 0.0) + +{% endhighlight %} + +Now register the streams for training and testing and start the job, printing +the predicted cluster assignments on new data points as they arrive. + +{% highlight scala %} + +model.trainOn(trainingData) +model.predictOnValues(testData).print() + +ssc.start() +ssc.awaitTermination() + +{% endhighlight %} + +As you add new text files with data the cluster centers will update. Each training +point should be formatted as `[x1, x2, x3]`, and each test data point +should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier +(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` +you will see predictions. With new data, the cluster centers will change! + +
+ +
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala new file mode 100644 index 0000000000..33e5760aed --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.clustering.StreamingKMeans +import org.apache.spark.SparkConf +import org.apache.spark.streaming.{Seconds, StreamingContext} + +/** + * Estimate clusters on one stream of data and make predictions + * on another stream, where the data streams arrive as text files + * into two different directories. + * + * The rows of the training text files must be vector data in the form + * `[x1,x2,x3,...,xn]` + * Where n is the number of dimensions. + * + * The rows of the test text files must be labeled data in the form + * `(y,[x1,x2,x3,...,xn])` + * Where y is some identifier. n must be the same for train and test. + * + * Usage: StreamingKmeans + * + * To run on your local machine using the two directories `trainingDir` and `testDir`, + * with updates every 5 seconds, 2 dimensions per data point, and 3 clusters, call: + * $ bin/run-example \ + * org.apache.spark.examples.mllib.StreamingKMeans trainingDir testDir 5 3 2 + * + * As you add text files to `trainingDir` the clusters will continuously update. + * Anytime you add text files to `testDir`, you'll see predicted labels using the current model. + * + */ +object StreamingKMeans { + + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println( + "Usage: StreamingKMeans " + + " ") + System.exit(1) + } + + val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression") + val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) + + val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse) + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) + + val model = new StreamingKMeans() + .setK(args(3).toInt) + .setDecayFactor(1.0) + .setRandomCenters(args(4).toInt, 0.0) + + model.trainOn(trainingData) + model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala new file mode 100644 index 0000000000..6189dce9b2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.reflect.ClassTag + +import org.apache.spark.Logging +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom + +/** + * :: DeveloperApi :: + * StreamingKMeansModel extends MLlib's KMeansModel for streaming + * algorithms, so it can keep track of a continuously updated weight + * associated with each cluster, and also update the model by + * doing a single iteration of the standard k-means algorithm. + * + * The update algorithm uses the "mini-batch" KMeans rule, + * generalized to incorporate forgetfullness (i.e. decay). + * The update rule (for each cluster) is: + * + * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] + * n_t+t = n_t * a + m_t + * + * Where c_t is the previously estimated centroid for that cluster, + * n_t is the number of points assigned to it thus far, x_t is the centroid + * estimated on the current batch, and m_t is the number of points assigned + * to that centroid in the current batch. + * + * The decay factor 'a' scales the contribution of the clusters as estimated thus far, + * by applying a as a discount weighting on the current point when evaluating + * new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids + * are determined entirely by recent data. Lower values correspond to + * more forgetting. + * + * Decay can optionally be specified by a half life and associated + * time unit. The time unit can either be a batch of data or a single + * data point. Considering data arrived at time t, the half life h is defined + * such that at time t + h the discount applied to the data from t is 0.5. + * The definition remains the same whether the time unit is given + * as batches or points. + * + */ +@DeveloperApi +class StreamingKMeansModel( + override val clusterCenters: Array[Vector], + val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { + + /** Perform a k-means update on a batch of data. */ + def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = { + + // find nearest cluster to each point + val closest = data.map(point => (this.predict(point), (point, 1L))) + + // get sums and counts for updating each cluster + val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => { + BLAS.axpy(1.0, p2._1, p1._1) + (p1._1, p1._2 + p2._2) + } + val dim = clusterCenters(0).size + val pointStats: Array[(Int, (Vector, Long))] = closest + .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs) + .collect() + + val discount = timeUnit match { + case StreamingKMeans.BATCHES => decayFactor + case StreamingKMeans.POINTS => + val numNewPoints = pointStats.view.map { case (_, (_, n)) => + n + }.sum + math.pow(decayFactor, numNewPoints) + } + + // apply discount to weights + BLAS.scal(discount, Vectors.dense(clusterWeights)) + + // implement update rule + pointStats.foreach { case (label, (sum, count)) => + val centroid = clusterCenters(label) + + val updatedWeight = clusterWeights(label) + count + val lambda = count / math.max(updatedWeight, 1e-16) + + clusterWeights(label) = updatedWeight + BLAS.scal(1.0 - lambda, centroid) + BLAS.axpy(lambda / count, sum, centroid) + + // display the updated cluster centers + val display = clusterCenters(label).size match { + case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...") + case _ => centroid.toArray.mkString("[", ",", "]") + } + + logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display") + } + + // Check whether the smallest cluster is dying. If so, split the largest cluster. + val weightsWithIndex = clusterWeights.view.zipWithIndex + val (maxWeight, largest) = weightsWithIndex.maxBy(_._1) + val (minWeight, smallest) = weightsWithIndex.minBy(_._1) + if (minWeight < 1e-8 * maxWeight) { + logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.") + val weight = (maxWeight + minWeight) / 2.0 + clusterWeights(largest) = weight + clusterWeights(smallest) = weight + val largestClusterCenter = clusterCenters(largest) + val smallestClusterCenter = clusterCenters(smallest) + var j = 0 + while (j < dim) { + val x = largestClusterCenter(j) + val p = 1e-14 * math.max(math.abs(x), 1.0) + largestClusterCenter.toBreeze(j) = x + p + smallestClusterCenter.toBreeze(j) = x - p + j += 1 + } + } + + this + } +} + +/** + * :: DeveloperApi :: + * StreamingKMeans provides methods for configuring a + * streaming k-means analysis, training the model on streaming, + * and using the model to make predictions on streaming data. + * See KMeansModel for details on algorithm and update rules. + * + * Use a builder pattern to construct a streaming k-means analysis + * in an application, like: + * + * val model = new StreamingKMeans() + * .setDecayFactor(0.5) + * .setK(3) + * .setRandomCenters(5, 100.0) + * .trainOn(DStream) + */ +@DeveloperApi +class StreamingKMeans( + var k: Int, + var decayFactor: Double, + var timeUnit: String) extends Logging { + + def this() = this(2, 1.0, StreamingKMeans.BATCHES) + + protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null) + + /** Set the number of clusters. */ + def setK(k: Int): this.type = { + this.k = k + this + } + + /** Set the decay factor directly (for forgetful algorithms). */ + def setDecayFactor(a: Double): this.type = { + this.decayFactor = decayFactor + this + } + + /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */ + def setHalfLife(halfLife: Double, timeUnit: String): this.type = { + if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) { + throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit) + } + this.decayFactor = math.exp(math.log(0.5) / halfLife) + logInfo("Setting decay factor to: %g ".format (this.decayFactor)) + this.timeUnit = timeUnit + this + } + + /** Specify initial centers directly. */ + def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { + model = new StreamingKMeansModel(centers, weights) + this + } + + /** + * Initialize random centers, requiring only the number of dimensions. + * + * @param dim Number of dimensions + * @param weight Weight for each center + * @param seed Random seed + */ + def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { + val random = new XORShiftRandom(seed) + val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian()))) + val weights = Array.fill(k)(weight) + model = new StreamingKMeansModel(centers, weights) + this + } + + /** Return the latest model. */ + def latestModel(): StreamingKMeansModel = { + model + } + + /** + * Update the clustering model by training on batches of data from a DStream. + * This operation registers a DStream for training the model, + * checks whether the cluster centers have been initialized, + * and updates the model using each batch of data from the stream. + * + * @param data DStream containing vector data + */ + def trainOn(data: DStream[Vector]) { + assertInitialized() + data.foreachRDD { (rdd, time) => + model = model.update(rdd, decayFactor, timeUnit) + } + } + + /** + * Use the clustering model to make predictions on batches of data from a DStream. + * + * @param data DStream containing vector data + * @return DStream containing predictions + */ + def predictOn(data: DStream[Vector]): DStream[Int] = { + assertInitialized() + data.map(model.predict) + } + + /** + * Use the model to make predictions on the values of a DStream and carry over its keys. + * + * @param data DStream containing (key, feature vector) pairs + * @tparam K key type + * @return DStream containing the input keys and the predictions as values + */ + def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = { + assertInitialized() + data.mapValues(model.predict) + } + + /** Check whether cluster centers have been initialized. */ + private[this] def assertInitialized(): Unit = { + if (model.clusterCenters == null) { + throw new IllegalStateException( + "Initial cluster centers must be set before starting predictions") + } + } +} + +private[clustering] object StreamingKMeans { + final val BATCHES = "batches" + final val POINTS = "points" +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala new file mode 100644 index 0000000000..850c9fce50 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.random.XORShiftRandom + +class StreamingKMeansSuite extends FunSuite with TestSuiteBase { + + override def maxWaitTimeMillis = 30000 + + test("accuracy for single center and equivalence to grand average") { + // set parameters + val numBatches = 10 + val numPoints = 50 + val k = 1 + val d = 5 + val r = 0.1 + + // create model with one cluster + val model = new StreamingKMeans() + .setK(1) + .setDecayFactor(1.0) + .setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0)), Array(0.0)) + + // generate random data for k-means + val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) + + // setup and run the model training + val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + model.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // estimated center should be close to true center + assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1) + + // estimated center from streaming should exactly match the arithmetic mean of all data points + // because the decay factor is set to 1.0 + val grandMean = + input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble + assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5) + } + + test("accuracy for two centers") { + val numBatches = 10 + val numPoints = 5 + val k = 2 + val d = 5 + val r = 0.1 + + // create model with two clusters + val kMeans = new StreamingKMeans() + .setK(2) + .setHalfLife(2, "batches") + .setInitialCenters( + Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1), + Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1)), + Array(5.0, 5.0)) + + // generate random data for k-means + val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) + + // setup and run the model training + val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + kMeans.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // check that estimated centers are close to true centers + // NOTE exact assignment depends on the initialization! + assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1) + assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1) + } + + test("detecting dying clusters") { + val numBatches = 10 + val numPoints = 5 + val k = 1 + val d = 1 + val r = 1.0 + + // create model with two clusters + val kMeans = new StreamingKMeans() + .setK(2) + .setHalfLife(0.5, "points") + .setInitialCenters( + Array(Vectors.dense(0.0), Vectors.dense(1000.0)), + Array(1.0, 1.0)) + + // new data are all around the first cluster 0.0 + val (input, _) = + StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0))) + + // setup and run the model training + val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + kMeans.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // check that estimated centers are close to true centers + // NOTE exact assignment depends on the initialization! + val model = kMeans.latestModel() + val c0 = model.clusterCenters(0)(0) + val c1 = model.clusterCenters(1)(0) + + assert(c0 * c1 < 0.0, "should have one positive center and one negative center") + // 0.8 is the mean of half-normal distribution + assert(math.abs(c0) ~== 0.8 absTol 0.6) + assert(math.abs(c1) ~== 0.8 absTol 0.6) + } + + def StreamingKMeansDataGenerator( + numPoints: Int, + numBatches: Int, + k: Int, + d: Int, + r: Double, + seed: Int, + initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[Vector]], Array[Vector]) = { + val rand = new XORShiftRandom(seed) + val centers = initCenters match { + case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian()))) + case _ => initCenters + } + val data = (0 until numBatches).map { i => + (0 until numPoints).map { idx => + val center = centers(idx % k) + Vectors.dense(Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r)) + } + } + (data, centers) + } +} -- cgit v1.2.3