diff options
author | root <root@ip-10-6-154-245.ec2.internal> | 2012-11-11 07:05:22 +0000 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2012-11-10 23:07:21 -0800 |
commit | acf827232458e87773a71a38f88cb7ba9a6ab77e (patch) | |
tree | 084083078865e51ef407f03e622cab8f3ad6683c /examples/src | |
parent | d0f0fc8c1eea2d7b4fa3220ff68feb9686269810 (diff) | |
download | spark-acf827232458e87773a71a38f88cb7ba9a6ab77e.tar.gz spark-acf827232458e87773a71a38f88cb7ba9a6ab77e.tar.bz2 spark-acf827232458e87773a71a38f88cb7ba9a6ab77e.zip |
Fix K-means example a little
Diffstat (limited to 'examples/src')
-rw-r--r-- | examples/src/main/scala/spark/examples/SparkKMeans.scala | 27 |
1 files changed, 11 insertions, 16 deletions
```diff
diff --git a/examples/src/main/scala/spark/examples/SparkKMeans.scala b/examples/src/main/scala/spark/examples/SparkKMeans.scala
index adce551322..6375961390 100644
--- a/examples/src/main/scala/spark/examples/SparkKMeans.scala
+++ b/examples/src/main/scala/spark/examples/SparkKMeans.scala
@@ -15,14 +15,13 @@ object SparkKMeans {
     return new Vector(line.split(' ').map(_.toDouble))
   }
 
-  def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = {
+  def closestPoint(p: Vector, centers: Array[Vector]): Int = {
     var index = 0
     var bestIndex = 0
     var closest = Double.PositiveInfinity
 
-    for (i <- 1 to centers.size) {
-      val vCurr = centers.get(i).get
-      val tempDist = p.squaredDist(vCurr)
+    for (i <- 0 until centers.length) {
+      val tempDist = p.squaredDist(centers(i))
       if (tempDist < closest) {
         closest = tempDist
         bestIndex = i
@@ -43,32 +42,28 @@ object SparkKMeans {
     val K = args(2).toInt
     val convergeDist = args(3).toDouble
 
-    var points = data.takeSample(false, K, 42)
-    var kPoints = new HashMap[Int, Vector]
+    var kPoints = data.takeSample(false, K, 42).toArray
     var tempDist = 1.0
-
-    for (i <- 1 to points.size) {
-      kPoints.put(i, points(i-1))
-    }
 
     while(tempDist > convergeDist) {
       var closest = data.map (p => (closestPoint(p, kPoints), (p, 1)))
 
-      var pointStats = closest.reduceByKey {case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)}
+      var pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)}
 
-      var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collect()
+      var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap()
 
       tempDist = 0.0
-      for (pair <- newPoints) {
-        tempDist += kPoints.get(pair._1).get.squaredDist(pair._2)
+      for (i <- 0 until K) {
+        tempDist += kPoints(i).squaredDist(newPoints(i))
       }
 
       for (newP <- newPoints) {
-        kPoints.put(newP._1, newP._2)
+        kPoints(newP._1) = newP._2
       }
     }
 
-    println("Final centers: " + kPoints)
+    println("Final centers:")
+    kPoints.foreach(println)
     System.exit(0)
   }
 }
```