aboutsummaryrefslogtreecommitdiff
path: root/python/examples/kmeans.py
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@eecs.berkeley.edu>2013-01-01 14:48:45 -0800
committerJosh Rosen <joshrosen@eecs.berkeley.edu>2013-01-01 15:05:00 -0800
commitb58340dbd9a741331fc4c3829b08c093560056c2 (patch)
tree52b0e94c47892a8f884b2f80a59ccdb1a428b389 /python/examples/kmeans.py
parent170e451fbdd308ae77065bd9c0f2bd278abf0cb7 (diff)
downloadspark-b58340dbd9a741331fc4c3829b08c093560056c2.tar.gz
spark-b58340dbd9a741331fc4c3829b08c093560056c2.tar.bz2
spark-b58340dbd9a741331fc4c3829b08c093560056c2.zip
Rename top-level 'pyspark' directory to 'python'
Diffstat (limited to 'python/examples/kmeans.py')
-rw-r--r--python/examples/kmeans.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/python/examples/kmeans.py b/python/examples/kmeans.py
new file mode 100644
index 0000000000..ad2be21178
--- /dev/null
+++ b/python/examples/kmeans.py
@@ -0,0 +1,52 @@
+"""
+This example requires numpy (http://www.numpy.org/)
+"""
+import sys
+
+import numpy as np
+from pyspark import SparkContext
+
+
+def parseVector(line):
+ return np.array([float(x) for x in line.split(' ')])
+
+
+def closestPoint(p, centers):
+ bestIndex = 0
+ closest = float("+inf")
+ for i in range(len(centers)):
+ tempDist = np.sum((p - centers[i]) ** 2)
+ if tempDist < closest:
+ closest = tempDist
+ bestIndex = i
+ return bestIndex
+
+
+if __name__ == "__main__":
+ if len(sys.argv) < 5:
+ print >> sys.stderr, \
+ "Usage: PythonKMeans <master> <file> <k> <convergeDist>"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonKMeans")
+ lines = sc.textFile(sys.argv[2])
+ data = lines.map(parseVector).cache()
+ K = int(sys.argv[3])
+ convergeDist = float(sys.argv[4])
+
+ kPoints = data.takeSample(False, K, 34)
+ tempDist = 1.0
+
+ while tempDist > convergeDist:
+ closest = data.map(
+ lambda p : (closestPoint(p, kPoints), (p, 1)))
+ pointStats = closest.reduceByKey(
+ lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2))
+ newPoints = pointStats.map(
+ lambda (x, (y, z)): (x, y / z)).collect()
+
+ tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints)
+
+ for (x, y) in newPoints:
+ kPoints[x] = y
+
+ print "Final centers: " + str(kPoints)