From 100e8007822040df34e4e47f872d183a48e9c7f4 Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Tue, 31 Jan 2012 00:33:18 -0800
Subject: Some fixes to the examples (mostly to use functional API)

---
 .../main/scala/spark/examples/SparkHdfsLR.scala    |  10 +-
 .../main/scala/spark/examples/SparkKMeans.scala    | 118 ++++++++++-----------
 .../src/main/scala/spark/examples/SparkLR.scala    |  10 +-
 .../src/main/scala/spark/examples/SparkPi.scala    |  10 +-
 4 files changed, 72 insertions(+), 76 deletions(-)

diff --git a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala
index 4c71fd0845..f4cb8f7903 100644
--- a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala
@@ -40,12 +40,10 @@ object SparkHdfsLR {
 
     for (i <- 1 to ITERATIONS) {
       println("On iteration " + i)
-      val gradient = sc.accumulator(Vector.zeros(D))
-      for (p <- points) {
-        val scale = (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y
-        gradient += scale * p.x
-      }
-      w -= gradient.value
+      val gradient = points.map { p =>
+        (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x
+      }.reduce(_ + _)
+      w -= gradient
     }
 
     println("Final w: " + w)
diff --git a/examples/src/main/scala/spark/examples/SparkKMeans.scala b/examples/src/main/scala/spark/examples/SparkKMeans.scala
index b0d3407801..3139a0a6e2 100644
--- a/examples/src/main/scala/spark/examples/SparkKMeans.scala
+++ b/examples/src/main/scala/spark/examples/SparkKMeans.scala
@@ -8,66 +8,66 @@ import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 
 object SparkKMeans {
-	val R = 1000 // Scaling factor
-	val rand = new Random(42)
-	
-	def parseVector(line: String): Vector = {
-		return new Vector(line.split(' ').map(_.toDouble))
-	}
-	
-	def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = {
-		var index = 0
-		var bestIndex = 0
-		var closest = Double.PositiveInfinity
-	
-		for (i <- 1 to centers.size) {
-			val vCurr = centers.get(i).get
-			val tempDist = p.squaredDist(vCurr)
-			if (tempDist < closest) {
-				closest = tempDist
-				bestIndex = i
-			}
-		}
-	
-		return bestIndex
-	}
+  val R = 1000 // Scaling factor
+  val rand = new Random(42)
+
+  def parseVector(line: String): Vector = {
+    return new Vector(line.split(' ').map(_.toDouble))
+  }
+
+  def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = {
+    var index = 0
+    var bestIndex = 0
+    var closest = Double.PositiveInfinity
+
+    for (i <- 1 to centers.size) {
+      val vCurr = centers.get(i).get
+      val tempDist = p.squaredDist(vCurr)
+      if (tempDist < closest) {
+        closest = tempDist
+        bestIndex = i
+      }
+    }
+
+    return bestIndex
+  }
 
-	def main(args: Array[String]) {
-		if (args.length < 4) {
-			System.err.println("Usage: SparkLocalKMeans <master> <file> <k> <convergeDist>")
-			System.exit(1)
-		}
-		val sc = new SparkContext(args(0), "SparkLocalKMeans")
-		val lines = sc.textFile(args(1))
-		val data = lines.map(parseVector _).cache()
-		val K = args(2).toInt
-		val convergeDist = args(3).toDouble
-	
-		var points = data.takeSample(false, K, 42)
-		var kPoints = new HashMap[Int, Vector]
-		var tempDist = 1.0
-	
-		for (i <- 1 to points.size) {
-			kPoints.put(i, points(i-1))
-		}
+  def main(args: Array[String]) {
+    if (args.length < 4) {
+      System.err.println("Usage: SparkLocalKMeans <master> <file> <k> <convergeDist>")
+      System.exit(1)
+    }
+    val sc = new SparkContext(args(0), "SparkLocalKMeans")
+    val lines = sc.textFile(args(1))
+    val data = lines.map(parseVector _).cache()
+    val K = args(2).toInt
+    val convergeDist = args(3).toDouble
+
+    var points = data.takeSample(false, K, 42)
+    var kPoints = new HashMap[Int, Vector]
+    var tempDist = 1.0
+
+    for (i <- 1 to points.size) {
+      kPoints.put(i, points(i-1))
+    }
 
-		while(tempDist > convergeDist) {
-			var closest = data.map (p => (closestPoint(p, kPoints), (p, 1)))
-			
-			var pointStats = closest.reduceByKey {case ((x1, y1), (x2, y2)) => (x1 + x2, y1+y2)}
-			
-			var newPoints = pointStats.map {mapping => (mapping._1, mapping._2._1/mapping._2._2)}.collect()
-			
-			tempDist = 0.0
-			for (mapping <- newPoints) {
-				tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2)
-			}
-			
-			for (newP <- newPoints) {
-				kPoints.put(newP._1, newP._2)
-			}
-		}
+    while(tempDist > convergeDist) {
+      var closest = data.map (p => (closestPoint(p, kPoints), (p, 1)))
+
+      var pointStats = closest.reduceByKey {case ((x1, y1), (x2, y2)) => (x1 + x2, y1+y2)}
+
+      var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collect()
+
+      tempDist = 0.0
+      for (pair <- newPoints) {
+        tempDist += kPoints.get(pair._1).get.squaredDist(pair._2)
+      }
+
+      for (newP <- newPoints) {
+        kPoints.put(newP._1, newP._2)
+      }
+    }
 
-		println("Final centers: " + kPoints)
-	}
+    println("Final centers: " + kPoints)
+  }
 }
diff --git a/examples/src/main/scala/spark/examples/SparkLR.scala b/examples/src/main/scala/spark/examples/SparkLR.scala
index faa8471824..207d936a15 100644
--- a/examples/src/main/scala/spark/examples/SparkLR.scala
+++ b/examples/src/main/scala/spark/examples/SparkLR.scala
@@ -38,12 +38,10 @@ object SparkLR {
 
     for (i <- 1 to ITERATIONS) {
       println("On iteration " + i)
-      val gradient = sc.accumulator(Vector.zeros(D))
-      for (p <- sc.parallelize(data, numSlices)) {
-        val scale = (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y
-        gradient += scale * p.x
-      }
-      w -= gradient.value
+      val gradient = sc.parallelize(data, numSlices).map { p =>
+        (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x
+      }.reduce(_ + _)
+      w -= gradient
     }
 
     println("Final w: " + w)
diff --git a/examples/src/main/scala/spark/examples/SparkPi.scala b/examples/src/main/scala/spark/examples/SparkPi.scala
index 31c6c5b9b1..4cce3c6f36 100644
--- a/examples/src/main/scala/spark/examples/SparkPi.scala
+++ b/examples/src/main/scala/spark/examples/SparkPi.scala
@@ -12,12 +12,12 @@ object SparkPi {
     }
     val spark = new SparkContext(args(0), "SparkPi")
     val slices = if (args.length > 1) args(1).toInt else 2
-    var count = spark.accumulator(0)
-    for (i <- spark.parallelize(1 to 100000, slices)) {
+    val n = 100000 * slices
+    val count = spark.parallelize(1 to n, slices).map { i =>
       val x = random * 2 - 1
       val y = random * 2 - 1
-      if (x*x + y*y < 1) count += 1
-    }
-    println("Pi is roughly " + 4 * count.value / 100000.0)
+      if (x*x + y*y < 1) 1 else 0
+    }.reduce(_ + _)
+    println("Pi is roughly " + 4.0 * count / n)
   }
 }
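
The refactor applied to SparkHdfsLR and SparkLR above replaces a mutable accumulator, updated once per point inside a distributed for-comprehension, with a pure map over the points followed by a reduce using vector addition. A minimal standalone sketch of the same pattern on plain Scala collections, with no Spark dependency; the Vec and DataPoint types here are hypothetical stand-ins for the examples' Vector and DataPoint classes:

import scala.math.exp

object GradientSketch {
  // Hypothetical stand-in for the examples' Vector class.
  case class Vec(elems: Array[Double]) {
    def dot(that: Vec): Double = elems.zip(that.elems).map { case (a, b) => a * b }.sum
    def +(that: Vec): Vec = Vec(elems.zip(that.elems).map { case (a, b) => a + b })
    def -(that: Vec): Vec = Vec(elems.zip(that.elems).map { case (a, b) => a - b })
    def *(s: Double): Vec = Vec(elems.map(_ * s))
  }
  case class DataPoint(x: Vec, y: Double)

  def main(args: Array[String]) {
    val points = Seq(
      DataPoint(Vec(Array(1.0, 2.0)), 1.0),
      DataPoint(Vec(Array(-1.5, -0.5)), -1.0))
    var w = Vec(Array(0.1, -0.2))
    // Each point is mapped to its gradient contribution; reduce(_ + _)
    // then sums the contributions, replacing the old mutable accumulator.
    // (scale * p.x is written p.x * scale here to avoid the implicit
    // Double-to-multiplier conversion the Spark examples rely on.)
    val gradient = points.map { p =>
      p.x * ((1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y)
    }.reduce(_ + _)
    w = w - gradient
    println("w after one step: " + w.elems.mkString(", "))
  }
}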
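
In SparkPi, the total sample count now scales with the number of slices (n = 100000 * slices) and the estimate divides by n. Writing 4.0 * count / n rather than 4 * count / n matters here: count and n are both Int, so the 4.0 forces the division to happen in floating point instead of truncating. A local sketch of the same Monte Carlo estimate on plain Scala collections; the sample count here is arbitrary:

import scala.math.random

object PiSketch {
  def main(args: Array[String]) {
    val n = 200000 // arbitrary sample count for this sketch
    // Draw points uniformly from the square [-1, 1) x [-1, 1); the
    // fraction that lands inside the unit circle approaches pi / 4.
    val count = (1 to n).map { _ =>
      val x = random * 2 - 1
      val y = random * 2 - 1
      if (x * x + y * y < 1) 1 else 0
    }.sum
    // 4.0 (not 4) keeps the division in floating point.
    println("Pi is roughly " + 4.0 * count / n)
  }
}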