aboutsummaryrefslogtreecommitdiff
path: root/pyspark/examples/kmeans.py
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@eecs.berkeley.edu>2012-12-28 22:51:28 -0800
committerJosh Rosen <joshrosen@eecs.berkeley.edu>2012-12-28 22:51:28 -0800
commitc2b105af34f7241ac0597d9c35fbf66633a3eaf6 (patch)
treee96946d2b714365937019f60741bf3ae62d565c6 /pyspark/examples/kmeans.py
parent7ec3595de28d53839cb3a45e940ec16f81ffdf45 (diff)
downloadspark-c2b105af34f7241ac0597d9c35fbf66633a3eaf6.tar.gz
spark-c2b105af34f7241ac0597d9c35fbf66633a3eaf6.tar.bz2
spark-c2b105af34f7241ac0597d9c35fbf66633a3eaf6.zip
Add documentation for Python API.
Diffstat (limited to 'pyspark/examples/kmeans.py')
-rw-r--r--pyspark/examples/kmeans.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/pyspark/examples/kmeans.py b/pyspark/examples/kmeans.py
new file mode 100644
index 0000000000..9cc366f03c
--- /dev/null
+++ b/pyspark/examples/kmeans.py
@@ -0,0 +1,49 @@
+import sys
+
+from pyspark.context import SparkContext
+from numpy import array, sum as np_sum
+
+
+def parseVector(line):
+ return array([float(x) for x in line.split(' ')])
+
+
+def closestPoint(p, centers):
+ bestIndex = 0
+ closest = float("+inf")
+ for i in range(len(centers)):
+ tempDist = np_sum((p - centers[i]) ** 2)
+ if tempDist < closest:
+ closest = tempDist
+ bestIndex = i
+ return bestIndex
+
+
+if __name__ == "__main__":
+ if len(sys.argv) < 5:
+ print >> sys.stderr, \
+ "Usage: PythonKMeans <master> <file> <k> <convergeDist>"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonKMeans")
+ lines = sc.textFile(sys.argv[2])
+ data = lines.map(parseVector).cache()
+ K = int(sys.argv[3])
+ convergeDist = float(sys.argv[4])
+
+ kPoints = data.takeSample(False, K, 34)
+ tempDist = 1.0
+
+ while tempDist > convergeDist:
+ closest = data.map(
+ lambda p : (closestPoint(p, kPoints), (p, 1)))
+ pointStats = closest.reduceByKey(
+ lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2))
+ newPoints = pointStats.map(
+ lambda (x, (y, z)): (x, y / z)).collect()
+
+ tempDist = sum(np_sum((kPoints[x] - y) ** 2) for (x, y) in newPoints)
+
+ for (x, y) in newPoints:
+ kPoints[x] = y
+
+ print "Final centers: " + str(kPoints)