author     Josh Rosen <rosenville@gmail.com>    2012-08-10 01:10:02 -0700
committer  Josh Rosen <rosenville@gmail.com>    2012-08-18 22:33:51 -0700
commit     886b39de557b4d5f54f5ca11559fca9799534280 (patch)
tree       ff4504773f3f75b2408f5acbc1a9e0e0b3b3ff64 /pyspark
parent     9a0c128feceb63685513ce9c1022ef2d4de43fbf (diff)
Add Python API.
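As a rough sketch of what the new API supports (assembled from the classes added in this patch; the master URL and input path are made up for illustration), a driver program could look like:

    from pyspark.context import SparkContext

    sc = SparkContext("local", "WordCount")
    lines = sc.textFile("/tmp/words.txt")              # hypothetical input file
    counts = lines.flatMap(lambda line: line.split(' ')) \
                  .mapPairs(lambda word: (word, 1)) \
                  .reduceByKey(lambda a, b: a + b)
    print counts.collect()                             # [(word, count), ...]
    sc.stop()

flatMap, mapPairs and reduceByKey are defined in pyspark/pyspark/rdd.py below; textFile and stop come from pyspark/pyspark/context.py.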
Diffstat (limited to 'pyspark')
-rw-r--r--  pyspark/pyspark/__init__.py               0
-rw-r--r--  pyspark/pyspark/context.py               69
-rw-r--r--  pyspark/pyspark/examples/__init__.py      0
-rw-r--r--  pyspark/pyspark/examples/kmeans.py       56
-rw-r--r--  pyspark/pyspark/examples/pi.py           20
-rw-r--r--  pyspark/pyspark/examples/tc.py           49
-rw-r--r--  pyspark/pyspark/java_gateway.py          20
-rw-r--r--  pyspark/pyspark/join.py                 104
-rw-r--r--  pyspark/pyspark/rdd.py                  517
-rw-r--r--  pyspark/pyspark/serializers.py          229
-rw-r--r--  pyspark/pyspark/worker.py                97
-rw-r--r--  pyspark/requirements.txt                  9
12 files changed, 1170 insertions, 0 deletions
diff --git a/pyspark/pyspark/__init__.py b/pyspark/pyspark/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/pyspark/pyspark/__init__.py
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
new file mode 100644
index 0000000000..587ab12b5f
--- /dev/null
+++ b/pyspark/pyspark/context.py
@@ -0,0 +1,69 @@
+import os
+import atexit
+from tempfile import NamedTemporaryFile
+
+from pyspark.java_gateway import launch_gateway
+from pyspark.serializers import JSONSerializer, NopSerializer
+from pyspark.rdd import RDD, PairRDD
+
+
+class SparkContext(object):
+
+ gateway = launch_gateway()
+ jvm = gateway.jvm
+ python_dump = jvm.spark.api.python.PythonRDD.pythonDump
+
+ def __init__(self, master, name, defaultSerializer=JSONSerializer,
+ defaultParallelism=None, pythonExec='python'):
+ self.master = master
+ self.name = name
+ self._jsc = self.jvm.JavaSparkContext(master, name)
+ self.defaultSerializer = defaultSerializer
+ self.defaultParallelism = \
+ defaultParallelism or self._jsc.sc().defaultParallelism()
+ self.pythonExec = pythonExec
+
+ def __del__(self):
+ if self._jsc:
+ self._jsc.stop()
+
+ def stop(self):
+ self._jsc.stop()
+ self._jsc = None
+
+ def parallelize(self, c, numSlices=None, serializer=None):
+ serializer = serializer or self.defaultSerializer
+ numSlices = numSlices or self.defaultParallelism
+ # Calling the Java parallelize() method with an ArrayList is too slow,
+ # because it sends O(n) Py4J commands. As an alternative, serialized
+ # objects are written to a file and loaded through textFile().
+ tempFile = NamedTemporaryFile(delete=False)
+ tempFile.writelines(serializer.dumps(x) + '\n' for x in c)
+ tempFile.close()
+ atexit.register(lambda: os.unlink(tempFile.name))
+ return self.textFile(tempFile.name, numSlices, serializer)
+
+ def parallelizePairs(self, c, numSlices=None, keySerializer=None,
+ valSerializer=None):
+ """
+ >>> sc = SparkContext("local", "test")
+ >>> rdd = sc.parallelizePairs([(1, 2), (3, 4)])
+ >>> rdd.collect()
+ [(1, 2), (3, 4)]
+ """
+ keySerializer = keySerializer or self.defaultSerializer
+ valSerializer = valSerializer or self.defaultSerializer
+ numSlices = numSlices or self.defaultParallelism
+ tempFile = NamedTemporaryFile(delete=False)
+ for (k, v) in c:
+ tempFile.write(keySerializer.dumps(k).rstrip('\r\n') + '\n')
+ tempFile.write(valSerializer.dumps(v).rstrip('\r\n') + '\n')
+ tempFile.close()
+ atexit.register(lambda: os.unlink(tempFile.name))
+ jrdd = self.textFile(tempFile.name, numSlices)._pipePairs([], "echo")
+ return PairRDD(jrdd, self, keySerializer, valSerializer)
+
+ def textFile(self, name, numSlices=None, serializer=NopSerializer):
+ numSlices = numSlices or self.defaultParallelism
+ jrdd = self._jsc.textFile(name, numSlices)
+ return RDD(jrdd, self, serializer)
diff --git a/pyspark/pyspark/examples/__init__.py b/pyspark/pyspark/examples/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/pyspark/pyspark/examples/__init__.py
diff --git a/pyspark/pyspark/examples/kmeans.py b/pyspark/pyspark/examples/kmeans.py
new file mode 100644
index 0000000000..0761d6e395
--- /dev/null
+++ b/pyspark/pyspark/examples/kmeans.py
@@ -0,0 +1,56 @@
+import sys
+
+from pyspark.context import SparkContext
+
+
+def parseVector(line):
+ return [float(x) for x in line.split(' ')]
+
+
+def addVec(x, y):
+ return [a + b for (a, b) in zip(x, y)]
+
+
+def squaredDist(x, y):
+ return sum((a - b) ** 2 for (a, b) in zip(x, y))
+
+
+def closestPoint(p, centers):
+ bestIndex = 0
+ closest = float("+inf")
+ for i in range(len(centers)):
+ tempDist = squaredDist(p, centers[i])
+ if tempDist < closest:
+ closest = tempDist
+ bestIndex = i
+ return bestIndex
+
+
+if __name__ == "__main__":
+ if len(sys.argv) < 5:
+ print >> sys.stderr, \
+ "Usage: PythonKMeans <master> <file> <k> <convergeDist>"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonKMeans")
+ lines = sc.textFile(sys.argv[2])
+ data = lines.map(parseVector).cache()
+ K = int(sys.argv[3])
+ convergeDist = float(sys.argv[4])
+
+ kPoints = data.takeSample(False, K, 34)
+ tempDist = 1.0
+
+ while tempDist > convergeDist:
+ closest = data.mapPairs(
+ lambda p: (closestPoint(p, kPoints), (p, 1)))
+ pointStats = closest.reduceByKey(
+ lambda (x1, y1), (x2, y2): (addVec(x1, x2), y1 + y2))
+ newPoints = pointStats.mapPairs(
+ lambda (x, (y, z)): (x, [a / z for a in y])).collect()
+
+ tempDist = sum(squaredDist(kPoints[x], y) for (x, y) in newPoints)
+
+ for (x, y) in newPoints:
+ kPoints[x] = y
+
+ print "Final centers: " + str(kPoints)
diff --git a/pyspark/pyspark/examples/pi.py b/pyspark/pyspark/examples/pi.py
new file mode 100644
index 0000000000..ad77694c41
--- /dev/null
+++ b/pyspark/pyspark/examples/pi.py
@@ -0,0 +1,20 @@
+import sys
+from random import random
+from operator import add
+from pyspark.context import SparkContext
+
+
+if __name__ == "__main__":
+ if len(sys.argv) == 1:
+ print >> sys.stderr, \
+ "Usage: PythonPi <host> [<slices>]"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonPi")
+ slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
+ n = 100000 * slices
+ def f(_):
+ x = random() * 2 - 1
+ y = random() * 2 - 1
+ return 1 if x ** 2 + y ** 2 < 1 else 0
+ count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add)
+ print "Pi is roughly %f" % (4.0 * count / n)
diff --git a/pyspark/pyspark/examples/tc.py b/pyspark/pyspark/examples/tc.py
new file mode 100644
index 0000000000..2796fdc6ad
--- /dev/null
+++ b/pyspark/pyspark/examples/tc.py
@@ -0,0 +1,49 @@
+import sys
+from random import Random
+from pyspark.context import SparkContext
+
+numEdges = 200
+numVertices = 100
+rand = Random(42)
+
+
+def generateGraph():
+ edges = set()
+ while len(edges) < numEdges:
+ src = rand.randrange(0, numVertices)
+ dst = rand.randrange(0, numVertices)
+ if src != dst:
+ edges.add((src, dst))
+ return edges
+
+
+if __name__ == "__main__":
+ if len(sys.argv) == 1:
+ print >> sys.stderr, \
+ "Usage: PythonTC <host> [<slices>]"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonTC")
+ slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
+ tc = sc.parallelizePairs(generateGraph(), slices).cache()
+
+ # Linear transitive closure: each round grows paths by one edge,
+ # by joining the graph's edges with the already-discovered paths.
+ # e.g. join the path (y, z) from the TC with the edge (x, y) from
+ # the graph to obtain the path (x, z).
+
+ # Because join() joins on keys, the edges are stored in reversed order.
+ edges = tc.mapPairs(lambda (x, y): (y, x))
+
+ oldCount = 0L
+ nextCount = tc.count()
+ while True:
+ oldCount = nextCount
+ # Perform the join, obtaining an RDD of (y, (z, x)) pairs,
+ # then project the result to obtain the new (x, z) paths.
+ new_edges = tc.join(edges).mapPairs(lambda (_, (a, b)): (b, a))
+ tc = tc.union(new_edges).distinct().cache()
+ nextCount = tc.count()
+ if nextCount == oldCount:
+ break
+
+ print "TC has %i edges" % tc.count()
diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py
new file mode 100644
index 0000000000..2df80aee85
--- /dev/null
+++ b/pyspark/pyspark/java_gateway.py
@@ -0,0 +1,20 @@
+import glob
+import os
+from py4j.java_gateway import java_import, JavaGateway
+
+
+SPARK_HOME = os.environ["SPARK_HOME"]
+
+
+assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \
+ "/spark-core-assembly-*-SNAPSHOT.jar")[0]
+
+
+def launch_gateway():
+ gateway = JavaGateway.launch_gateway(classpath=assembly_jar,
+ javaopts=["-Xmx256m"], die_on_exit=True)
+ java_import(gateway.jvm, "spark.api.java.*")
+ java_import(gateway.jvm, "spark.api.python.*")
+ java_import(gateway.jvm, "scala.Tuple2")
+ java_import(gateway.jvm, "spark.api.python.PythonRDD.pythonDump")
+ return gateway
diff --git a/pyspark/pyspark/join.py b/pyspark/pyspark/join.py
new file mode 100644
index 0000000000..c67520fce8
--- /dev/null
+++ b/pyspark/pyspark/join.py
@@ -0,0 +1,104 @@
+"""
+Copyright (c) 2011, Douban Inc. <http://www.douban.com/>
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+
+ * Neither the name of the Douban Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+from pyspark.serializers import PairSerializer, OptionSerializer, \
+ ArraySerializer
+
+
+def _do_python_join(rdd, other, numSplits, dispatch, valSerializer):
+ vs = rdd.mapPairs(lambda (k, v): (k, (1, v)))
+ ws = other.mapPairs(lambda (k, v): (k, (2, v)))
+ return vs.union(ws).groupByKey(numSplits) \
+ .flatMapValues(dispatch, valSerializer)
+
+
+def python_join(rdd, other, numSplits):
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ return [(v, w) for v in vbuf for w in wbuf]
+ valSerializer = PairSerializer(rdd.valSerializer, other.valSerializer)
+ return _do_python_join(rdd, other, numSplits, dispatch, valSerializer)
+
+
+def python_right_outer_join(rdd, other, numSplits):
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ if not vbuf:
+ vbuf.append(None)
+ return [(v, w) for v in vbuf for w in wbuf]
+ valSerializer = PairSerializer(OptionSerializer(rdd.valSerializer),
+ other.valSerializer)
+ return _do_python_join(rdd, other, numSplits, dispatch, valSerializer)
+
+
+def python_left_outer_join(rdd, other, numSplits):
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ if not wbuf:
+ wbuf.append(None)
+ return [(v, w) for v in vbuf for w in wbuf]
+ valSerializer = PairSerializer(rdd.valSerializer,
+ OptionSerializer(other.valSerializer))
+ return _do_python_join(rdd, other, numSplits, dispatch, valSerializer)
+
+
+def python_cogroup(rdd, other, numSplits):
+ resultValSerializer = PairSerializer(
+ ArraySerializer(rdd.valSerializer),
+ ArraySerializer(other.valSerializer))
+ vs = rdd.mapPairs(lambda (k, v): (k, (1, v)))
+ ws = other.mapPairs(lambda (k, v): (k, (2, v)))
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ return (vbuf, wbuf)
+ return vs.union(ws).groupByKey(numSplits) \
+ .mapValues(dispatch, resultValSerializer)
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
new file mode 100644
index 0000000000..c892e86b93
--- /dev/null
+++ b/pyspark/pyspark/rdd.py
@@ -0,0 +1,517 @@
+from base64 import standard_b64encode as b64enc
+from cloud.serialization import cloudpickle
+from itertools import chain
+
+from pyspark.serializers import PairSerializer, NopSerializer, \
+ OptionSerializer, ArraySerializer
+from pyspark.join import python_join, python_left_outer_join, \
+ python_right_outer_join, python_cogroup
+
+
+class RDD(object):
+
+ def __init__(self, jrdd, ctx, serializer=None):
+ self._jrdd = jrdd
+ self.is_cached = False
+ self.ctx = ctx
+ self.serializer = serializer or ctx.defaultSerializer
+
+ def _builder(self, jrdd, ctx):
+ return RDD(jrdd, ctx, self.serializer)
+
+ @property
+ def id(self):
+ return self._jrdd.id()
+
+ @property
+ def splits(self):
+ return self._jrdd.splits()
+
+ @classmethod
+ def _get_pipe_command(cls, command, functions):
+ if functions and not isinstance(functions, (list, tuple)):
+ functions = [functions]
+ worker_args = [command]
+ for f in functions:
+ worker_args.append(b64enc(cloudpickle.dumps(f)))
+ return " ".join(worker_args)
+
+ def cache(self):
+ self.is_cached = True
+ self._jrdd.cache()
+ return self
+
+ def map(self, f, serializer=None, preservesPartitioning=False):
+ return MappedRDD(self, f, serializer, preservesPartitioning)
+
+ def mapPairs(self, f, keySerializer=None, valSerializer=None,
+ preservesPartitioning=False):
+ return PairMappedRDD(self, f, keySerializer, valSerializer,
+ preservesPartitioning)
+
+ def flatMap(self, f, serializer=None):
+ """
+ >>> rdd = sc.parallelize([2, 3, 4])
+ >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect())
+ [1, 1, 1, 2, 2, 3]
+ """
+ serializer = serializer or self.ctx.defaultSerializer
+ dumps = serializer.dumps
+ loads = self.serializer.loads
+ def func(x):
+ pickled_elems = (dumps(y) for y in f(loads(x)))
+ return "\n".join(pickled_elems) or None
+ pipe_command = RDD._get_pipe_command("map", [func])
+ class_manifest = self._jrdd.classManifest()
+ jrdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
+ False, self.ctx.pythonExec,
+ class_manifest).asJavaRDD()
+ return RDD(jrdd, self.ctx, serializer)
+
+ def flatMapPairs(self, f, keySerializer=None, valSerializer=None,
+ preservesPartitioning=False):
+ """
+ >>> rdd = sc.parallelize([2, 3, 4])
+ >>> sorted(rdd.flatMapPairs(lambda x: [(x, x), (x, x)]).collect())
+ [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
+ """
+ keySerializer = keySerializer or self.ctx.defaultSerializer
+ valSerializer = valSerializer or self.ctx.defaultSerializer
+ dumpk = keySerializer.dumps
+ dumpv = valSerializer.dumps
+ loads = self.serializer.loads
+ def func(x):
+ pairs = f(loads(x))
+ pickled_pairs = ((dumpk(k), dumpv(v)) for (k, v) in pairs)
+ return "\n".join(chain.from_iterable(pickled_pairs)) or None
+ pipe_command = RDD._get_pipe_command("map", [func])
+ class_manifest = self._jrdd.classManifest()
+ python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command,
+ preservesPartitioning, self.ctx.pythonExec, class_manifest)
+ return PairRDD(python_rdd.asJavaPairRDD(), self.ctx, keySerializer,
+ valSerializer)
+
+ def filter(self, f):
+ """
+ >>> rdd = sc.parallelize([1, 2, 3, 4, 5])
+ >>> rdd.filter(lambda x: x % 2 == 0).collect()
+ [2, 4]
+ """
+ loads = self.serializer.loads
+ def filter_func(x): return x if f(loads(x)) else None
+ return self._builder(self._pipe(filter_func), self.ctx)
+
+ def _pipe(self, functions, command="map"):
+ class_manifest = self._jrdd.classManifest()
+ pipe_command = RDD._get_pipe_command(command, functions)
+ python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
+ False, self.ctx.pythonExec, class_manifest)
+ return python_rdd.asJavaRDD()
+
+ def _pipePairs(self, functions, command="mapPairs",
+ preservesPartitioning=False):
+ class_manifest = self._jrdd.classManifest()
+ pipe_command = RDD._get_pipe_command(command, functions)
+ python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command,
+ preservesPartitioning, self.ctx.pythonExec, class_manifest)
+ return python_rdd.asJavaPairRDD()
+
+ def distinct(self):
+ """
+ >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
+ [1, 2, 3]
+ """
+ if self.serializer.is_comparable:
+ return self._builder(self._jrdd.distinct(), self.ctx)
+ return self.mapPairs(lambda x: (x, "")) \
+ .reduceByKey(lambda x, _: x) \
+ .map(lambda (x, _): x)
+
+ def sample(self, withReplacement, fraction, seed):
+ jrdd = self._jrdd.sample(withReplacement, fraction, seed)
+ return self._builder(jrdd, self.ctx)
+
+ def takeSample(self, withReplacement, num, seed):
+ vals = self._jrdd.takeSample(withReplacement, num, seed)
+ return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals]
+
+ def union(self, other):
+ """
+ >>> rdd = sc.parallelize([1, 1, 2, 3])
+ >>> rdd.union(rdd).collect()
+ [1, 1, 2, 3, 1, 1, 2, 3]
+ """
+ return self._builder(self._jrdd.union(other._jrdd), self.ctx)
+
+ # TODO: sort
+
+ # TODO: Overload __add__?
+
+ # TODO: glom
+
+ def cartesian(self, other):
+ """
+ >>> rdd = sc.parallelize([1, 2])
+ >>> sorted(rdd.cartesian(rdd).collect())
+ [(1, 1), (1, 2), (2, 1), (2, 2)]
+ """
+ return PairRDD(self._jrdd.cartesian(other._jrdd), self.ctx)
+
+ # numsplits
+ def groupBy(self, f, numSplits=None):
+ """
+ >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8])
+ >>> sorted(rdd.groupBy(lambda x: x % 2).collect())
+ [(0, [2, 8]), (1, [1, 1, 3, 5])]
+ """
+ return self.mapPairs(lambda x: (f(x), x)).groupByKey(numSplits)
+
+ # TODO: pipe
+
+ # TODO: mapPartitions
+
+ def foreach(self, f):
+ """
+ >>> def f(x): print x
+ >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
+ """
+ self.map(f).collect() # Force evaluation
+
+ def collect(self):
+ vals = self._jrdd.collect()
+ return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals]
+
+ def reduce(self, f, serializer=None):
+ """
+ >>> import operator
+ >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(operator.add)
+ 15
+ """
+ serializer = serializer or self.ctx.defaultSerializer
+ loads = self.serializer.loads
+ dumps = serializer.dumps
+ def reduceFunction(x, acc):
+ if acc is None:
+ return loads(x)
+ else:
+ return f(loads(x), acc)
+ vals = self._pipe([reduceFunction, dumps], command="reduce").collect()
+ return reduce(f, (serializer.loads(x) for x in vals))
+
+ # TODO: fold
+
+ # TODO: aggregate
+
+ def count(self):
+ """
+ >>> sc.parallelize([2, 3, 4]).count()
+ 3L
+ """
+ return self._jrdd.count()
+
+ # TODO: count approx methods
+
+ def take(self, num):
+ """
+ >>> sc.parallelize([2, 3, 4]).take(2)
+ [2, 3]
+ """
+ vals = self._jrdd.take(num)
+ return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals]
+
+ def first(self):
+ """
+ >>> sc.parallelize([2, 3, 4]).first()
+ 2
+ """
+ return self.serializer.loads(self.ctx.python_dump(self._jrdd.first()))
+
+ # TODO: saveAsTextFile
+
+ # TODO: saveAsObjectFile
+
+
+class PairRDD(RDD):
+
+ def __init__(self, jrdd, ctx, keySerializer=None, valSerializer=None):
+ RDD.__init__(self, jrdd, ctx)
+ self.keySerializer = keySerializer or ctx.defaultSerializer
+ self.valSerializer = valSerializer or ctx.defaultSerializer
+ self.serializer = \
+ PairSerializer(self.keySerializer, self.valSerializer)
+
+ def _builder(self, jrdd, ctx):
+ return PairRDD(jrdd, ctx, self.keySerializer, self.valSerializer)
+
+ def reduceByKey(self, func, numSplits=None):
+ """
+ >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(x.reduceByKey(lambda a, b: a + b).collect())
+ [('a', 2), ('b', 1)]
+ """
+ return self.combineByKey(lambda x: x, func, func, numSplits)
+
+ # TODO: reduceByKeyLocally()
+
+ # TODO: countByKey()
+
+ # TODO: partitionBy
+
+ def join(self, other, numSplits=None):
+ """
+ >>> x = sc.parallelizePairs([("a", 1), ("b", 4)])
+ >>> y = sc.parallelizePairs([("a", 2), ("a", 3)])
+ >>> x.join(y).collect()
+ [('a', (1, 2)), ('a', (1, 3))]
+
+ Check that we get a PairRDD-like object back:
+ >>> assert x.join(y).join
+ """
+ assert self.keySerializer.name == other.keySerializer.name
+ if self.keySerializer.is_comparable:
+ return PairRDD(self._jrdd.join(other._jrdd),
+ self.ctx, self.keySerializer,
+ PairSerializer(self.valSerializer, other.valSerializer))
+ else:
+ return python_join(self, other, numSplits)
+
+ def leftOuterJoin(self, other, numSplits=None):
+ """
+ >>> x = sc.parallelizePairs([("a", 1), ("b", 4)])
+ >>> y = sc.parallelizePairs([("a", 2)])
+ >>> sorted(x.leftOuterJoin(y).collect())
+ [('a', (1, 2)), ('b', (4, None))]
+ """
+ assert self.keySerializer.name == other.keySerializer.name
+ if self.keySerializer.is_comparable:
+ return PairRDD(self._jrdd.leftOuterJoin(other._jrdd),
+ self.ctx, self.keySerializer,
+ PairSerializer(self.valSerializer,
+ OptionSerializer(other.valSerializer)))
+ else:
+ return python_left_outer_join(self, other, numSplits)
+
+ def rightOuterJoin(self, other, numSplits=None):
+ """
+ >>> x = sc.parallelizePairs([("a", 1), ("b", 4)])
+ >>> y = sc.parallelizePairs([("a", 2)])
+ >>> sorted(y.rightOuterJoin(x).collect())
+ [('a', (2, 1)), ('b', (None, 4))]
+ """
+ assert self.keySerializer.name == other.keySerializer.name
+ if self.keySerializer.is_comparable:
+ return PairRDD(self._jrdd.rightOuterJoin(other._jrdd),
+ self.ctx, self.keySerializer,
+ PairSerializer(OptionSerializer(self.valSerializer),
+ other.valSerializer))
+ else:
+ return python_right_outer_join(self, other, numSplits)
+
+ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
+ numSplits=None, serializer=None):
+ """
+ >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)])
+ >>> def f(x): return x
+ >>> def add(a, b): return a + str(b)
+ >>> sorted(x.combineByKey(str, add, add).collect())
+ [('a', '11'), ('b', '1')]
+ """
+ serializer = serializer or self.ctx.defaultSerializer
+ if numSplits is None:
+ numSplits = self.ctx.defaultParallelism
+ # Use hash() to create keys that are comparable in Java.
+ loadkv = self.serializer.loads
+ def pairify(kv):
+ # TODO: add method to deserialize only the key or value from
+ # a PairSerializer?
+ key = loadkv(kv)[0]
+ return (str(hash(key)), kv)
+ partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
+ jrdd = self._pipePairs(pairify).partitionBy(partitioner)
+ pairified = PairRDD(jrdd, self.ctx, NopSerializer, self.serializer)
+
+ loads = PairSerializer(NopSerializer, self.serializer).loads
+ dumpk = self.keySerializer.dumps
+ dumpc = serializer.dumps
+
+ functions = [createCombiner, mergeValue, mergeCombiners, loads, dumpk,
+ dumpc]
+ jpairs = pairified._pipePairs(functions, "combine_by_key",
+ preservesPartitioning=True)
+ return PairRDD(jpairs, self.ctx, self.keySerializer, serializer)
+
+ def groupByKey(self, numSplits=None):
+ """
+ >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(x.groupByKey().collect())
+ [('a', [1, 1]), ('b', [1])]
+ """
+
+ def createCombiner(x):
+ return [x]
+
+ def mergeValue(xs, x):
+ xs.append(x)
+ return xs
+
+ def mergeCombiners(a, b):
+ return a + b
+
+ return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
+ numSplits)
+
+ def collectAsMap(self):
+ """
+ >>> m = sc.parallelizePairs([(1, 2), (3, 4)]).collectAsMap()
+ >>> m[1]
+ 2
+ >>> m[3]
+ 4
+ """
+ m = self._jrdd.collectAsMap()
+ def loads(x):
+ (k, v) = x
+ return (self.keySerializer.loads(k), self.valSerializer.loads(v))
+ return dict(loads(x) for x in m.items())
+
+ def flatMapValues(self, f, valSerializer=None):
+ flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
+ return self.flatMapPairs(flat_map_fn, self.keySerializer,
+ valSerializer, True)
+
+ def mapValues(self, f, valSerializer=None):
+ map_values_fn = lambda (k, v): (k, f(v))
+ return self.mapPairs(map_values_fn, self.keySerializer, valSerializer,
+ True)
+
+ # TODO: support varargs cogroup of several RDDs.
+ def groupWith(self, other):
+ return self.cogroup(other)
+
+ def cogroup(self, other, numSplits=None):
+ """
+ >>> x = sc.parallelizePairs([("a", 1), ("b", 4)])
+ >>> y = sc.parallelizePairs([("a", 2)])
+ >>> x.cogroup(y).collect()
+ [('a', ([1], [2])), ('b', ([4], []))]
+ """
+ assert self.keySerializer.name == other.keySerializer.name
+ resultValSerializer = PairSerializer(
+ ArraySerializer(self.valSerializer),
+ ArraySerializer(other.valSerializer))
+ if self.keySerializer.is_comparable:
+ return PairRDD(self._jrdd.cogroup(other._jrdd),
+ self.ctx, self.keySerializer, resultValSerializer)
+ else:
+ return python_cogroup(self, other, numSplits)
+
+ # TODO: `lookup` is disabled because we can't make direct comparisons based
+ # on the key; we need to compare the hash of the key to the hash of the
+ # keys in the pairs. This could be an expensive operation, since those
+ # hashes aren't retained.
+
+ # TODO: file saving
+
+
+class MappedRDDBase(object):
+ def __init__(self, prev, func, serializer, preservesPartitioning=False):
+ if isinstance(prev, MappedRDDBase) and not prev.is_cached:
+ prev_func = prev.func
+ self.func = lambda x: func(prev_func(x))
+ self.preservesPartitioning = \
+ prev.preservesPartitioning and preservesPartitioning
+ self._prev_jrdd = prev._prev_jrdd
+ self._prev_serializer = prev._prev_serializer
+ else:
+ self.func = func
+ self.preservesPartitioning = preservesPartitioning
+ self._prev_jrdd = prev._jrdd
+ self._prev_serializer = prev.serializer
+ self.serializer = serializer or prev.ctx.defaultSerializer
+ self.is_cached = False
+ self.ctx = prev.ctx
+ self.prev = prev
+ self._jrdd_val = None
+
+
+class MappedRDD(MappedRDDBase, RDD):
+ """
+ >>> rdd = sc.parallelize([1, 2, 3, 4])
+ >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
+ [4, 8, 12, 16]
+ >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
+ [4, 8, 12, 16]
+ """
+
+ @property
+ def _jrdd(self):
+ if not self._jrdd_val:
+ udf = self.func
+ loads = self._prev_serializer.loads
+ dumps = self.serializer.dumps
+ func = lambda x: dumps(udf(loads(x)))
+ pipe_command = RDD._get_pipe_command("map", [func])
+ class_manifest = self._prev_jrdd.classManifest()
+ python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
+ pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
+ class_manifest)
+ self._jrdd_val = python_rdd.asJavaRDD()
+ return self._jrdd_val
+
+
+class PairMappedRDD(MappedRDDBase, PairRDD):
+ """
+ >>> rdd = sc.parallelize([1, 2, 3, 4])
+ >>> rdd.mapPairs(lambda x: (x, x)) \\
+ ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\
+ ... .collect()
+ [(2, 2), (4, 4), (6, 6), (8, 8)]
+ >>> rdd.mapPairs(lambda x: (x, x)) \\
+ ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\
+ ... .map(lambda (x, _): x).collect()
+ [2, 4, 6, 8]
+ """
+
+ def __init__(self, prev, func, keySerializer=None, valSerializer=None,
+ preservesPartitioning=False):
+ self.keySerializer = keySerializer or prev.ctx.defaultSerializer
+ self.valSerializer = valSerializer or prev.ctx.defaultSerializer
+ serializer = PairSerializer(self.keySerializer, self.valSerializer)
+ MappedRDDBase.__init__(self, prev, func, serializer,
+ preservesPartitioning)
+
+ @property
+ def _jrdd(self):
+ if not self._jrdd_val:
+ udf = self.func
+ loads = self._prev_serializer.loads
+ dumpk = self.keySerializer.dumps
+ dumpv = self.valSerializer.dumps
+ def func(x):
+ (k, v) = udf(loads(x))
+ return (dumpk(k), dumpv(v))
+ pipe_command = RDD._get_pipe_command("mapPairs", [func])
+ class_manifest = self._prev_jrdd.classManifest()
+ self._jrdd_val = self.ctx.jvm.PythonPairRDD(self._prev_jrdd.rdd(),
+ pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
+ class_manifest).asJavaPairRDD()
+ return self._jrdd_val
+
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.serializers import PickleSerializer, JSONSerializer
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local', 'PythonTest',
+ defaultSerializer=JSONSerializer)
+ doctest.testmod(globs=globs)
+ globs['sc'].stop()
+ globs['sc'] = SparkContext('local', 'PythonTest',
+ defaultSerializer=PickleSerializer)
+ doctest.testmod(globs=globs)
+ globs['sc'].stop()
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
new file mode 100644
index 0000000000..b113f5656b
--- /dev/null
+++ b/pyspark/pyspark/serializers.py
@@ -0,0 +1,229 @@
+"""
+Data serialization methods.
+
+The Spark Python API is built on top of the Spark Java API. RDDs created in
+Python are stored in Java as RDDs of Strings. Python objects are automatically
+serialized/deserialized, so this representation is transparent to the end-user.
+
+------------------
+Serializer objects
+------------------
+
+`Serializer` objects are used to customize how an RDD's values are serialized.
+
+Each `Serializer` is a named tuple with four fields:
+
+ - A `dumps` function, for serializing a Python object to a string.
+
+ - A `loads` function, for deserializing a Python object from a string.
+
+ - An `is_comparable` field, True if equal Python objects are serialized to
+ equal strings, and False otherwise.
+
+ - A `name` field, used to identify the Serializer. Serializers are
+ compared for equality by comparing their names.
+
+The serializer's output should be base64-encoded.
+
+------------------------------------------------------------------
+`is_comparable`: comparing serialized representations for equality
+------------------------------------------------------------------
+
+If `is_comparable` is False, the serializer's representations of equal objects
+are not required to be equal:
+
+>>> import pickle
+>>> a = {1: 0, 9: 0}
+>>> b = {9: 0, 1: 0}
+>>> a == b
+True
+>>> pickle.dumps(a) == pickle.dumps(b)
+False
+
+RDDs with comparable serializers can use native Java implementations of
+operations like join() and distinct(), which may lead to better performance by
+eliminating deserialization and Python comparisons.
+
+The default JSONSerializer produces comparable representations of common Python
+data structures.
+
+--------------------------------------
+Examples of serialized representations
+--------------------------------------
+
+The RDD transformations that use Python UDFs are implemented in terms of
+a modified `PipedRDD.pipe()` function. For each record `x` in the RDD, the
+`pipe()` function pipes `x.toString()` to a Python worker process, which
+deserializes the string into a Python object, executes user-defined functions,
+and outputs serialized Python objects.
+
+The regular `toString()` method returns an ambiguous representation, due to the
+way that Scala `Option` instances are printed:
+
+>>> from context import SparkContext
+>>> sc = SparkContext("local", "SerializerDocs")
+>>> x = sc.parallelizePairs([("a", 1), ("b", 4)])
+>>> y = sc.parallelizePairs([("a", 2)])
+
+>>> print y.rightOuterJoin(x)._jrdd.first().toString()
+(ImEi,(Some(Mg==),MQ==))
+
+In Java, preprocessing is performed to handle Option instances, so the Python
+process receives unambiguous input:
+
+>>> print sc.python_dump(y.rightOuterJoin(x)._jrdd.first())
+(ImEi,(Mg==,MQ==))
+
+The base64-encoding eliminates the need to escape newlines, parentheses and
+other special characters.
+
+----------------------
+Serializer composition
+----------------------
+
+In order to handle nested structures, which could contain objects serialized
+with different serializers, the RDD module composes serializers. For example,
+the serializers in the previous example are:
+
+>>> print x.serializer.name
+PairSerializer<JSONSerializer, JSONSerializer>
+
+>>> print y.serializer.name
+PairSerializer<JSONSerializer, JSONSerializer>
+
+>>> print y.rightOuterJoin(x).serializer.name
+PairSerializer<JSONSerializer, PairSerializer<OptionSerializer<JSONSerializer>, JSONSerializer>>
+"""
+from base64 import standard_b64encode, standard_b64decode
+from collections import namedtuple
+import cPickle
+import simplejson
+
+
+Serializer = namedtuple("Serializer",
+ ["dumps","loads", "is_comparable", "name"])
+
+
+NopSerializer = Serializer(str, str, True, "NopSerializer")
+
+
+JSONSerializer = Serializer(
+ lambda obj: standard_b64encode(simplejson.dumps(obj, sort_keys=True,
+ separators=(',', ':'))),
+ lambda s: simplejson.loads(standard_b64decode(s)),
+ True,
+ "JSONSerializer"
+)
+
+
+PickleSerializer = Serializer(
+ lambda obj: standard_b64encode(cPickle.dumps(obj)),
+ lambda s: cPickle.loads(standard_b64decode(s)),
+ False,
+ "PickleSerializer"
+)
+
+
+def OptionSerializer(serializer):
+ """
+ >>> ser = OptionSerializer(NopSerializer)
+ >>> ser.loads(ser.dumps("Hello, World!"))
+ 'Hello, World!'
+ >>> ser.loads(ser.dumps(None)) is None
+ True
+ """
+ none_placeholder = '*'
+
+ def dumps(x):
+ if x is None:
+ return none_placeholder
+ else:
+ return serializer.dumps(x)
+
+ def loads(x):
+ if x == none_placeholder:
+ return None
+ else:
+ return serializer.loads(x)
+
+ name = "OptionSerializer<%s>" % serializer.name
+ return Serializer(dumps, loads, serializer.is_comparable, name)
+
+
+def PairSerializer(keySerializer, valSerializer):
+ """
+ Returns a Serializer for a (key, value) pair.
+
+ >>> ser = PairSerializer(JSONSerializer, JSONSerializer)
+ >>> ser.loads(ser.dumps((1, 2)))
+ (1, 2)
+
+ >>> ser = PairSerializer(JSONSerializer, ser)
+ >>> ser.loads(ser.dumps((1, (2, 3))))
+ (1, (2, 3))
+ """
+ def loads(kv):
+ try:
+ (key, val) = kv[1:-1].split(',', 1)
+ key = keySerializer.loads(key)
+ val = valSerializer.loads(val)
+ return (key, val)
+ except:
+ print "Error in deserializing pair from '%s'" % str(kv)
+ raise
+
+ def dumps(kv):
+ (key, val) = kv
+ return"(%s,%s)" % (keySerializer.dumps(key), valSerializer.dumps(val))
+ is_comparable = \
+ keySerializer.is_comparable and valSerializer.is_comparable
+ name = "PairSerializer<%s, %s>" % (keySerializer.name, valSerializer.name)
+ return Serializer(dumps, loads, is_comparable, name)
+
+
+def ArraySerializer(serializer):
+ """
+ >>> ser = ArraySerializer(JSONSerializer)
+ >>> ser.loads(ser.dumps([1, 2, 3, 4]))
+ [1, 2, 3, 4]
+ >>> ser = ArraySerializer(PairSerializer(JSONSerializer, PickleSerializer))
+ >>> ser.loads(ser.dumps([('a', 1), ('b', 2)]))
+ [('a', 1), ('b', 2)]
+ >>> ser.loads(ser.dumps([('a', 1)]))
+ [('a', 1)]
+ >>> ser.loads(ser.dumps([]))
+ []
+ """
+ def dumps(arr):
+ if arr == []:
+ return '[]'
+ else:
+ return '[' + '|'.join(serializer.dumps(x) for x in arr) + ']'
+
+ def loads(s):
+ if s == '[]':
+ return []
+ items = s[1:-1]
+ if '|' in items:
+ items = items.split('|')
+ else:
+ items = [items]
+ return [serializer.loads(x) for x in items]
+
+ name = "ArraySerializer<%s>" % serializer.name
+ return Serializer(dumps, loads, serializer.is_comparable, name)
+
+
+# TODO: IntegerSerializer
+
+
+# TODO: DoubleSerializer
+
+
+def _test():
+ import doctest
+ doctest.testmod()
+
+
+if __name__ == "__main__":
+ _test()
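The module above defines serializers as four-field named tuples and leaves IntegerSerializer and DoubleSerializer as TODOs. A possible integer serializer, sketched against that contract (the name IntSerializer is made up here), could look like:

    from base64 import standard_b64encode, standard_b64decode
    from pyspark.serializers import Serializer

    # Base64 keeps the representation newline-free, and equal ints always
    # encode to equal strings, so is_comparable can be True.
    IntSerializer = Serializer(
        lambda i: standard_b64encode(str(i)),
        lambda s: int(standard_b64decode(s)),
        True,
        "IntSerializer"
    )

Such a serializer could then be passed as defaultSerializer to SparkContext, or as the serializer argument to the RDD methods that accept one.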
diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py
new file mode 100644
index 0000000000..4d4cc939c3
--- /dev/null
+++ b/pyspark/pyspark/worker.py
@@ -0,0 +1,97 @@
+"""
+Worker that receives input from a piped RDD.
+"""
+import sys
+from base64 import standard_b64decode
+# CloudPickler needs to be imported so that depicklers are registered using the
+# copy_reg module.
+from cloud.serialization.cloudpickle import CloudPickler
+import cPickle
+
+
+# Redirect stdout to stderr so that users must return values from functions.
+old_stdout = sys.stdout
+sys.stdout = sys.stderr
+
+
+def load_function():
+ return cPickle.loads(standard_b64decode(sys.stdin.readline().strip()))
+
+
+def output(x):
+ for line in x.split("\n"):
+ old_stdout.write(line.rstrip("\r\n") + "\n")
+
+
+def read_input():
+ for line in sys.stdin:
+ yield line.rstrip("\r\n")
+
+
+def do_combine_by_key():
+ create_combiner = load_function()
+ merge_value = load_function()
+ merge_combiners = load_function() # TODO: not used.
+ depickler = load_function()
+ key_pickler = load_function()
+ combiner_pickler = load_function()
+ combiners = {}
+ for line in read_input():
+ # Discard the hashcode added in the Python combineByKey() method.
+ (key, value) = depickler(line)[1]
+ if key not in combiners:
+ combiners[key] = create_combiner(value)
+ else:
+ combiners[key] = merge_value(combiners[key], value)
+ for (key, combiner) in combiners.iteritems():
+ output(key_pickler(key))
+ output(combiner_pickler(combiner))
+
+
+def do_map(map_pairs=False):
+ f = load_function()
+ for line in read_input():
+ try:
+ out = f(line)
+ if out is not None:
+ if map_pairs:
+ for x in out:
+ output(x)
+ else:
+ output(out)
+ except:
+ sys.stderr.write("Error processing line '%s'\n" % line)
+ raise
+
+
+def do_reduce():
+ f = load_function()
+ dumps = load_function()
+ acc = None
+ for line in read_input():
+ acc = f(line, acc)
+ output(dumps(acc))
+
+
+def do_echo():
+ old_stdout.writelines(sys.stdin.readlines())
+
+
+def main():
+ command = sys.stdin.readline().strip()
+ if command == "map":
+ do_map(map_pairs=False)
+ elif command == "mapPairs":
+ do_map(map_pairs=True)
+ elif command == "combine_by_key":
+ do_combine_by_key()
+ elif command == "reduce":
+ do_reduce()
+ elif command == "echo":
+ do_echo()
+ else:
+ raise Exception("Unsupported command %s" % command)
+
+
+if __name__ == '__main__':
+ main()
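worker.py expects the command on the first line of stdin, one base64-encoded pickled function per load_function() call, and serialized records on the remaining lines. A rough way to exercise it by hand from the repository root, with the packages from requirements.txt installed (this drives the worker directly and does not show the exact framing used by the Scala PythonRDD, which is outside this patch):

    import subprocess
    from base64 import standard_b64encode as b64enc
    from cloud.serialization import cloudpickle

    # Send a "map" command, one pickled function, and two input records.
    func = lambda line: line.upper()
    payload = "map\n" + b64enc(cloudpickle.dumps(func)) + "\nfoo\nbar\n"
    proc = subprocess.Popen(["python", "pyspark/pyspark/worker.py"],
                            stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    out, _ = proc.communicate(payload)
    print out   # "FOO" and "BAR", one per line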
diff --git a/pyspark/requirements.txt b/pyspark/requirements.txt
new file mode 100644
index 0000000000..d9b3fe40bd
--- /dev/null
+++ b/pyspark/requirements.txt
@@ -0,0 +1,9 @@
+# The Python API relies on some new features from the Py4J development branch.
+# pip can't install Py4J from git because the setup.py file for the Python
+# package is not at the root of the git repository. It may be possible to
+# install Py4J from git once https://github.com/pypa/pip/pull/526 is merged.
+
+# git+git://github.com/bartdag/py4j.git@3dbf380d3d2cdeb9aab394454ea74d80c4aba1ea
+
+simplejson==2.6.1
+cloud==2.5.5