about summary refs log tree commit diff
path: root/examples
diff options
context:
space:
mode:
author root <root@ip-10-6-154-245.ec2.internal> 2012-11-11 07:05:22 +0000
committer Matei Zaharia <matei@eecs.berkeley.edu> 2012-11-10 23:07:21 -0800
commit acf827232458e87773a71a38f88cb7ba9a6ab77e (patch)
tree 084083078865e51ef407f03e622cab8f3ad6683c /examples
parent d0f0fc8c1eea2d7b4fa3220ff68feb9686269810 (diff)
download spark-acf827232458e87773a71a38f88cb7ba9a6ab77e.tar.gz
spark-acf827232458e87773a71a38f88cb7ba9a6ab77e.tar.bz2
spark-acf827232458e87773a71a38f88cb7ba9a6ab77e.zip
Fix K-means example a little
Diffstat (limited to 'examples')
-rw-r--r-- examples/src/main/scala/spark/examples/SparkKMeans.scala 27
1 files changed, 11 insertions, 16 deletions
diff --git a/examples/src/main/scala/spark/examples/SparkKMeans.scala b/examples/src/main/scala/spark/examples/SparkKMeans.scala
index adce551322..6375961390 100644
--- a/examples/src/main/scala/spark/examples/SparkKMeans.scala
+++ b/examples/src/main/scala/spark/examples/SparkKMeans.scala
@@ -15,14 +15,13 @@ object SparkKMeans {
return new Vector(line.split(' ').map(_.toDouble))
}
- def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = {
+ def closestPoint(p: Vector, centers: Array[Vector]): Int = {
var index = 0
var bestIndex = 0
var closest = Double.PositiveInfinity
- for (i <- 1 to centers.size) {
- val vCurr = centers.get(i).get
- val tempDist = p.squaredDist(vCurr)
+ for (i <- 0 until centers.length) {
+ val tempDist = p.squaredDist(centers(i))
if (tempDist < closest) {
closest = tempDist
bestIndex = i
@@ -43,32 +42,28 @@ object SparkKMeans {
val K = args(2).toInt
val convergeDist = args(3).toDouble
- var points = data.takeSample(false, K, 42)
- var kPoints = new HashMap[Int, Vector]
+ var kPoints = data.takeSample(false, K, 42).toArray
var tempDist = 1.0
-
- for (i <- 1 to points.size) {
- kPoints.put(i, points(i-1))
- }
while(tempDist > convergeDist) {
var closest = data.map (p => (closestPoint(p, kPoints), (p, 1)))
- var pointStats = closest.reduceByKey {case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)}
+ var pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)}
- var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collect()
+ var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap()
tempDist = 0.0
- for (pair <- newPoints) {
- tempDist += kPoints.get(pair._1).get.squaredDist(pair._2)
+ for (i <- 0 until K) {
+ tempDist += kPoints(i).squaredDist(newPoints(i))
}
for (newP <- newPoints) {
- kPoints.put(newP._1, newP._2)
+ kPoints(newP._1) = newP._2
}
}
- println("Final centers: " + kPoints)
+ println("Final centers:")
+ kPoints.foreach(println)
System.exit(0)
}
}