From f607ffb9e1f799d73818f1d37c633007a6b900fb Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Wed, 31 Jul 2013 14:31:07 -0700
Subject: Added data generator for K-means

Also made it possible to specify the number of runs in KMeans.main().
---
 .../main/scala/spark/mllib/clustering/KMeans.scala |  7 +-
 .../spark/mllib/util/KMeansDataGenerator.scala     | 80 ++++++++++++++++++++++
 2 files changed, 84 insertions(+), 3 deletions(-)
 create mode 100644 mllib/src/main/scala/spark/mllib/util/KMeansDataGenerator.scala

diff --git a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
index d875d6de50..a2ed42d7a5 100644
--- a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
@@ -315,14 +315,15 @@ object KMeans {
   }
 
   def main(args: Array[String]) {
-    if (args.length != 4) {
-      println("Usage: KMeans <master> <input_file> <k> <max_iterations>")
+    if (args.length < 4) {
+      println("Usage: KMeans <master> <input_file> <k> <max_iterations> [<runs>]")
       System.exit(1)
     }
     val (master, inputFile, k, iters) = (args(0), args(1), args(2).toInt, args(3).toInt)
+    val runs = if (args.length >= 5) args(4).toInt else 1
     val sc = new SparkContext(master, "KMeans")
     val data = sc.textFile(inputFile).map(line => line.split(' ').map(_.toDouble))
-    val model = KMeans.train(data, k, iters)
+    val model = KMeans.train(data, k, iters, runs)
     val cost = model.computeCost(data)
     println("Cluster centers:")
     for (c <- model.clusterCenters) {

diff --git a/mllib/src/main/scala/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/KMeansDataGenerator.scala
new file mode 100644
index 0000000000..8f95cf7479
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/util/KMeansDataGenerator.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.util
+
+import scala.util.Random
+
+import spark.{RDD, SparkContext}
+
+object KMeansDataGenerator {
+
+  /**
+   * Generate an RDD containing test data for KMeans. This function chooses k cluster centers
+   * from a d-dimensional Gaussian distribution scaled by factor r, then creates a Gaussian
+   * cluster with scale 1 around each center.
+   *
+   * @param sc SparkContext to use for creating the RDD
+   * @param numPoints Number of points that will be contained in the RDD
+   * @param k Number of clusters
+   * @param d Number of dimensions
+   * @param r Scaling factor for the distribution of the initial centers
+   * @param numPartitions Number of partitions of the generated RDD; default 2
+   */
+  def generateKMeansRDD(
+      sc: SparkContext,
+      numPoints: Int,
+      k: Int,
+      d: Int,
+      r: Double,
+      numPartitions: Int = 2)
+    : RDD[Array[Double]] =
+  {
+    // First, generate some centers
+    val rand = new Random(42)
+    val centers = Array.fill(k)(Array.fill(d)(rand.nextGaussian() * r))
+    // Then generate points around each center
+    sc.parallelize(0 until numPoints, numPartitions).map { idx =>
+      val center = centers(idx % k)
+      val rand2 = new Random(42 + idx)
+      Array.tabulate(d)(i => center(i) + rand2.nextGaussian())
+    }
+  }
+
+  def main(args: Array[String]) {
+    if (args.length < 6) {
+      println("Usage: KMeansGenerator " +
+        "<master> <output_dir> <num_points> <k> <d> <r> [<num_partitions>]")
+      System.exit(1)
+    }
+
+    val sparkMaster = args(0)
+    val outputPath = args(1)
+    val numPoints = args(2).toInt
+    val k = args(3).toInt
+    val d = args(4).toInt
+    val r = args(5).toDouble
+    val parts = if (args.length >= 7) args(6).toInt else 2
+
+    val sc = new SparkContext(sparkMaster, "KMeansDataGenerator")
+    val data = generateKMeansRDD(sc, numPoints, k, d, r, parts)
+    data.map(_.mkString(" ")).saveAsTextFile(outputPath)
+
+    System.exit(0)
+  }
+}
+
-- cgit v1.2.3


From 52dba89261ee6dddafff5c746322980567252843 Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Wed, 31 Jul 2013 23:08:12 -0700
Subject: Turn on caching in KMeans.main

---
 mllib/src/main/scala/spark/mllib/clustering/KMeans.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
index a2ed42d7a5..b402c71ed2 100644
--- a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
@@ -322,7 +322,7 @@ object KMeans {
     val (master, inputFile, k, iters) = (args(0), args(1), args(2).toInt, args(3).toInt)
     val runs = if (args.length >= 5) args(4).toInt else 1
     val sc = new SparkContext(master, "KMeans")
-    val data = sc.textFile(inputFile).map(line => line.split(' ').map(_.toDouble))
+    val data = sc.textFile(inputFile).map(line => line.split(' ').map(_.toDouble)).cache()
     val model = KMeans.train(data, k, iters, runs)
     val cost = model.computeCost(data)
     println("Cluster centers:")
-- cgit v1.2.3
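
A quick way to see the two commits working together is to generate a synthetic dataset with generateKMeansRDD, cache it (the point of the second commit, since k-means makes one pass over the data per iteration), and cluster it with the four-argument KMeans.train overload added by the first commit. The sketch below is illustrative, not part of the patches: it assumes the pre-Apache spark.* package layout used above, and the object name KMeansExample and all parameter values are made up for the example.

import spark.SparkContext
import spark.mllib.clustering.KMeans
import spark.mllib.util.KMeansDataGenerator

object KMeansExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "KMeansExample")

    // 10,000 points in 3 dimensions around 5 centers; the centers are drawn
    // from a Gaussian scaled by r = 10.0, so they are well separated relative
    // to the unit-scale clusters generated around them.
    val points = KMeansDataGenerator.generateKMeansRDD(sc, 10000, 5, 3, 10.0)

    // Cache before training: KMeans.train iterates over the data repeatedly,
    // which is what motivated adding .cache() in KMeans.main.
    points.cache()

    // k = 5 clusters, at most 20 iterations, 3 independent runs
    // (the runs parameter introduced in the first commit).
    val model = KMeans.train(points, 5, 20, 3)

    println("Cost: " + model.computeCost(points))
    model.clusterCenters.foreach(c => println(c.mkString(" ")))
    sc.stop()
  }
}

The same flow works from the command line through the two main() entry points: first spark.mllib.util.KMeansDataGenerator with arguments like "local out_dir 10000 5 3 10.0", then spark.mllib.clustering.KMeans with "local out_dir 5 20 3", launched with whichever runner script the checkout provides.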