aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/src/main/scala/spark/examples/SparkKMeans.scala67
-rw-r--r--examples/src/main/scala/spark/examples/Vector.scala22
-rw-r--r--kmeans_data.txt16
3 files changed, 102 insertions, 3 deletions
diff --git a/examples/src/main/scala/spark/examples/SparkKMeans.scala b/examples/src/main/scala/spark/examples/SparkKMeans.scala
new file mode 100644
index 0000000000..048001dc4f
--- /dev/null
+++ b/examples/src/main/scala/spark/examples/SparkKMeans.scala
@@ -0,0 +1,67 @@
+package spark.examples
+
+import java.util.Random
+import spark.SparkContext
+import spark.SparkContext._
+import spark.examples.Vector._
+
+object SparkKMeans {
+  /** Parses a line of space-separated doubles into a Vector. */
+  def parseVector(line: String): Vector =
+    new Vector(line.split(' ').map(_.toDouble))
+
+  /**
+   * Returns the index of the center in `centers` nearest to `p`
+   * by squared Euclidean distance. Assumes `centers` is non-empty.
+   */
+  def closestCenter(p: Vector, centers: Array[Vector]): Int = {
+    var bestIndex = 0
+    var bestDist = p.squaredDist(centers(0))
+    for (i <- 1 until centers.length) {
+      val dist = p.squaredDist(centers(i))
+      if (dist < bestDist) {
+        bestDist = dist
+        bestIndex = i
+      }
+    }
+    bestIndex
+  }
+
+  def main(args: Array[String]) {
+    // BUG FIX: the usage string requires five arguments and args(3)/args(4)
+    // are read below, but the original guard only checked for three — running
+    // with 3 or 4 args would crash with ArrayIndexOutOfBoundsException
+    // instead of printing the usage message.
+    if (args.length < 5) {
+      System.err.println("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters>")
+      System.exit(1)
+    }
+    val sc = new SparkContext(args(0), "SparkKMeans")
+    val lines = sc.textFile(args(1))
+    // Cache the parsed points: they are re-scanned on every iteration.
+    val points = lines.map(parseVector _).cache()
+    val dimensions = args(2).toInt
+    val k = args(3).toInt
+    val iterations = args(4).toInt
+
+    // Initialize cluster centers randomly (fixed seed for reproducibility),
+    // each coordinate uniform in [-1, 1). The array reference itself never
+    // changes, so it is a val; only its slots are updated below.
+    val rand = new Random(42)
+    val centers = new Array[Vector](k)
+    for (i <- 0 until k)
+      centers(i) = Vector(dimensions, _ => 2 * rand.nextDouble - 1)
+    println("Initial centers: " + centers.mkString(", "))
+
+    for (i <- 1 to iterations) {
+      println("On iteration " + i)
+
+      // Map each point to the index of its closest center and a (point, 1) pair
+      // that we will use to compute an average later
+      val mappedPoints = points.map { p => (closestCenter(p, centers), (p, 1)) }
+
+      // Compute the new centers by summing the (point, 1) pairs and taking an average
+      val newCenters = mappedPoints.reduceByKey {
+        case ((sum1, count1), (sum2, count2)) => (sum1 + sum2, count1 + count2)
+      }.map {
+        case (id, (sum, count)) => (id, sum / count)
+      }.collect
+
+      // Update the centers array with the new centers we collected.
+      // A center that attracted no points keeps its previous position.
+      for ((id, value) <- newCenters) {
+        centers(id) = value
+      }
+    }
+
+    println("Final centers: " + centers.mkString(", "))
+  }
+}
diff --git a/examples/src/main/scala/spark/examples/Vector.scala b/examples/src/main/scala/spark/examples/Vector.scala
index dd34dffee5..2abccbafce 100644
--- a/examples/src/main/scala/spark/examples/Vector.scala
+++ b/examples/src/main/scala/spark/examples/Vector.scala
@@ -21,19 +21,35 @@ class Vector(val elements: Array[Double]) extends Serializable {
if (length != other.length)
throw new IllegalArgumentException("Vectors of different length")
var ans = 0.0
- for (i <- 0 until length)
+ var i = 0
+ while (i < length) {
ans += this(i) * other(i)
+ i += 1
+ }
return ans
}
- def * ( scale: Double): Vector = Vector(length, i => this(i) * scale)
+ // Returns a new vector with every element multiplied by scale.
+ def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
+
+ // Element-wise division by a scalar, implemented as scaling by 1/d.
+ def / (d: Double): Vector = this * (1 / d)
 def unary_- = this * -1
 def sum = elements.reduceLeft(_ + _)
- override def toString = elements.mkString("(", ", ", ")")
+ // Squared Euclidean distance to other; hand-rolled while loop avoids
+ // allocating a Range in this hot path.
+ // NOTE(review): unlike the dot product above, this does not check that
+ // other has the same length — confirm callers guarantee it.
+ def squaredDist(other: Vector): Double = {
+ var ans = 0.0
+ var i = 0
+ while (i < length) {
+ ans += (this(i) - other(i)) * (this(i) - other(i))
+ i += 1
+ }
+ return ans
+ }
+
+ // Euclidean distance to other.
+ def dist(other: Vector): Double = math.sqrt(squaredDist(other))
+
+ override def toString = elements.mkString("(", ", ", ")")
}
object Vector {
diff --git a/kmeans_data.txt b/kmeans_data.txt
new file mode 100644
index 0000000000..06e5c9b45a
--- /dev/null
+++ b/kmeans_data.txt
@@ -0,0 +1,16 @@
+0.1 0.2 0.0 0.2
+0.2 0.2 0.3 0.2
+0.3 0.0 0.0 0.1
+0.1 0.2 0.3 0.2
+1.1 0.2 0.0 0.2
+1.2 0.2 0.3 0.2
+1.3 0.0 0.0 0.1
+1.1 0.2 0.3 0.2
+0.1 1.2 1.0 0.2
+0.2 1.2 1.3 0.2
+0.3 1.0 1.0 0.1
+0.1 1.2 1.3 0.2
+0.1 0.2 0.0 1.2
+0.2 0.2 0.3 1.2
+0.3 0.0 0.0 1.1
+0.1 0.2 0.3 1.2