From acf827232458e87773a71a38f88cb7ba9a6ab77e Mon Sep 17 00:00:00 2001 From: root Date: Sun, 11 Nov 2012 07:05:22 +0000 Subject: Fix K-means example a little --- .../main/scala/spark/examples/SparkKMeans.scala | 27 +++++++++------------- 1 file changed, 11 insertions(+), 16 deletions(-) (limited to 'examples/src') diff --git a/examples/src/main/scala/spark/examples/SparkKMeans.scala b/examples/src/main/scala/spark/examples/SparkKMeans.scala index adce551322..6375961390 100644 --- a/examples/src/main/scala/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/spark/examples/SparkKMeans.scala @@ -15,14 +15,13 @@ object SparkKMeans { return new Vector(line.split(' ').map(_.toDouble)) } - def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = { + def closestPoint(p: Vector, centers: Array[Vector]): Int = { var index = 0 var bestIndex = 0 var closest = Double.PositiveInfinity - for (i <- 1 to centers.size) { - val vCurr = centers.get(i).get - val tempDist = p.squaredDist(vCurr) + for (i <- 0 until centers.length) { + val tempDist = p.squaredDist(centers(i)) if (tempDist < closest) { closest = tempDist bestIndex = i @@ -43,32 +42,28 @@ object SparkKMeans { val K = args(2).toInt val convergeDist = args(3).toDouble - var points = data.takeSample(false, K, 42) - var kPoints = new HashMap[Int, Vector] + var kPoints = data.takeSample(false, K, 42).toArray var tempDist = 1.0 - - for (i <- 1 to points.size) { - kPoints.put(i, points(i-1)) - } while(tempDist > convergeDist) { var closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) - var pointStats = closest.reduceByKey {case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)} + var pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)} - var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collect() + var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap() tempDist = 0.0 - for (pair <- newPoints) { - tempDist += kPoints.get(pair._1).get.squaredDist(pair._2) + for (i <- 0 until K) { + tempDist += kPoints(i).squaredDist(newPoints(i)) } for (newP <- newPoints) { - kPoints.put(newP._1, newP._2) + kPoints(newP._1) = newP._2 } } - println("Final centers: " + kPoints) + println("Final centers:") + kPoints.foreach(println) System.exit(0) } } -- cgit v1.2.3