diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2011-11-01 19:25:58 -0700 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2011-11-01 19:25:58 -0700 |
commit | d4c8e69dc7851ffba577d4be3b3daf1723971300 (patch) | |
tree | b786f54fac7b6d4884c929f38348a1f6ffeac1b9 /examples/src | |
parent | 157279e9eb4d732c94cf4c5a7cfcd840b0da300c (diff) | |
download | spark-d4c8e69dc7851ffba577d4be3b3daf1723971300.tar.gz spark-d4c8e69dc7851ffba577d4be3b3daf1723971300.tar.bz2 spark-d4c8e69dc7851ffba577d4be3b3daf1723971300.zip |
K-means example
Diffstat (limited to 'examples/src')
-rw-r--r-- | examples/src/main/scala/spark/examples/SparkKMeans.scala | 67 | ||||
-rw-r--r-- | examples/src/main/scala/spark/examples/Vector.scala | 22 |
2 files changed, 86 insertions, 3 deletions
diff --git a/examples/src/main/scala/spark/examples/SparkKMeans.scala b/examples/src/main/scala/spark/examples/SparkKMeans.scala new file mode 100644 index 0000000000..048001dc4f --- /dev/null +++ b/examples/src/main/scala/spark/examples/SparkKMeans.scala @@ -0,0 +1,67 @@ +package spark.examples + +import java.util.Random +import spark.SparkContext +import spark.SparkContext._ +import spark.examples.Vector._ + +object SparkKMeans { + def parseVector(line: String): Vector = { + return new Vector(line.split(' ').map(_.toDouble)) + } + + def closestCenter(p: Vector, centers: Array[Vector]): Int = { + var bestIndex = 0 + var bestDist = p.squaredDist(centers(0)) + for (i <- 1 until centers.length) { + val dist = p.squaredDist(centers(i)) + if (dist < bestDist) { + bestDist = dist + bestIndex = i + } + } + return bestIndex + } + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters>") + System.exit(1) + } + val sc = new SparkContext(args(0), "SparkKMeans") + val lines = sc.textFile(args(1)) + val points = lines.map(parseVector _).cache() + val dimensions = args(2).toInt + val k = args(3).toInt + val iterations = args(4).toInt + + // Initialize cluster centers randomly + val rand = new Random(42) + var centers = new Array[Vector](k) + for (i <- 0 until k) + centers(i) = Vector(dimensions, _ => 2 * rand.nextDouble - 1) + println("Initial centers: " + centers.mkString(", ")) + + for (i <- 1 to iterations) { + println("On iteration " + i) + + // Map each point to the index of its closest center and a (point, 1) pair + // that we will use to compute an average later + val mappedPoints = points.map { p => (closestCenter(p, centers), (p, 1)) } + + // Compute the new centers by summing the (point, 1) pairs and taking an average + val newCenters = mappedPoints.reduceByKey { + case ((sum1, count1), (sum2, count2)) => (sum1 + sum2, count1 + count2) + }.map { + case (id, (sum, count)) => (id, sum / count) + }.collect + + // Update the centers array with the new centers we collected + for ((id, value) <- newCenters) { + centers(id) = value + } + } + + println("Final centers: " + centers.mkString(", ")) + } +} diff --git a/examples/src/main/scala/spark/examples/Vector.scala b/examples/src/main/scala/spark/examples/Vector.scala index dd34dffee5..2abccbafce 100644 --- a/examples/src/main/scala/spark/examples/Vector.scala +++ b/examples/src/main/scala/spark/examples/Vector.scala @@ -21,19 +21,35 @@ class Vector(val elements: Array[Double]) extends Serializable { if (length != other.length) throw new IllegalArgumentException("Vectors of different length") var ans = 0.0 - for (i <- 0 until length) + var i = 0 + while (i < length) { ans += this(i) * other(i) + i += 1 + } return ans } - def * ( scale: Double): Vector = Vector(length, i => this(i) * scale) + def * (scale: Double): Vector = Vector(length, i => this(i) * scale) + + def / (d: Double): Vector = this * (1 / d) def unary_- = this * -1 def sum = elements.reduceLeft(_ + _) - override def toString = elements.mkString("(", ", ", ")") + def squaredDist(other: Vector): Double = { + var ans = 0.0 + var i = 0 + while (i < length) { + ans += (this(i) - other(i)) * (this(i) - other(i)) + i += 1 + } + return ans + } + def dist(other: Vector): Double = math.sqrt(squaredDist(other)) + + override def toString = elements.mkString("(", ", ", ")") } object Vector { |