From 886b39de557b4d5f54f5ca11559fca9799534280 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 10 Aug 2012 01:10:02 -0700
Subject: Add Python API.

---
 pyspark/pyspark/__init__.py          |   0
 pyspark/pyspark/context.py           |  69 +++++
 pyspark/pyspark/examples/__init__.py |   0
 pyspark/pyspark/examples/kmeans.py   |  56 ++++
 pyspark/pyspark/examples/pi.py       |  20 ++
 pyspark/pyspark/examples/tc.py       |  49 ++++
 pyspark/pyspark/java_gateway.py      |  20 ++
 pyspark/pyspark/join.py              | 104 +++++++
 pyspark/pyspark/rdd.py               | 517 +++++++++++++++++++++++++++++++++++
 pyspark/pyspark/serializers.py       | 229 ++++++++++++++++
 pyspark/pyspark/worker.py            |  97 +++++++
 pyspark/requirements.txt             |   9 +
 12 files changed, 1170 insertions(+)
 create mode 100644 pyspark/pyspark/__init__.py
 create mode 100644 pyspark/pyspark/context.py
 create mode 100644 pyspark/pyspark/examples/__init__.py
 create mode 100644 pyspark/pyspark/examples/kmeans.py
 create mode 100644 pyspark/pyspark/examples/pi.py
 create mode 100644 pyspark/pyspark/examples/tc.py
 create mode 100644 pyspark/pyspark/java_gateway.py
 create mode 100644 pyspark/pyspark/join.py
 create mode 100644 pyspark/pyspark/rdd.py
 create mode 100644 pyspark/pyspark/serializers.py
 create mode 100644 pyspark/pyspark/worker.py
 create mode 100644 pyspark/requirements.txt

diff --git a/pyspark/pyspark/__init__.py b/pyspark/pyspark/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
new file mode 100644
index 0000000000..587ab12b5f
--- /dev/null
+++ b/pyspark/pyspark/context.py
@@ -0,0 +1,69 @@
+import os
+import atexit
+from tempfile import NamedTemporaryFile
+
+from pyspark.java_gateway import launch_gateway
+from pyspark.serializers import JSONSerializer, NopSerializer
+from pyspark.rdd import RDD, PairRDD
+
+
+class SparkContext(object):
+
+    gateway = launch_gateway()
+    jvm = gateway.jvm
+    python_dump = jvm.spark.api.python.PythonRDD.pythonDump
+
+    def __init__(self, master, name, defaultSerializer=JSONSerializer,
+                 defaultParallelism=None, pythonExec='python'):
+        self.master = master
+        self.name = name
+        self._jsc = self.jvm.JavaSparkContext(master, name)
+        self.defaultSerializer = defaultSerializer
+        self.defaultParallelism = \
+            defaultParallelism or self._jsc.sc().defaultParallelism()
+        self.pythonExec = pythonExec
+
+    def __del__(self):
+        if self._jsc:
+            self._jsc.stop()
+
+    def stop(self):
+        self._jsc.stop()
+        self._jsc = None
+
+    def parallelize(self, c, numSlices=None, serializer=None):
+        serializer = serializer or self.defaultSerializer
+        numSlices = numSlices or self.defaultParallelism
+        # Calling the Java parallelize() method with an ArrayList is too slow,
+        # because it sends O(n) Py4J commands.  As an alternative, serialized
+        # objects are written to a file and loaded through textFile().
+ tempFile = NamedTemporaryFile(delete=False) + tempFile.writelines(serializer.dumps(x) + '\n' for x in c) + tempFile.close() + atexit.register(lambda: os.unlink(tempFile.name)) + return self.textFile(tempFile.name, numSlices, serializer) + + def parallelizePairs(self, c, numSlices=None, keySerializer=None, + valSerializer=None): + """ + >>> sc = SparkContext("local", "test") + >>> rdd = sc.parallelizePairs([(1, 2), (3, 4)]) + >>> rdd.collect() + [(1, 2), (3, 4)] + """ + keySerializer = keySerializer or self.defaultSerializer + valSerializer = valSerializer or self.defaultSerializer + numSlices = numSlices or self.defaultParallelism + tempFile = NamedTemporaryFile(delete=False) + for (k, v) in c: + tempFile.write(keySerializer.dumps(k).rstrip('\r\n') + '\n') + tempFile.write(valSerializer.dumps(v).rstrip('\r\n') + '\n') + tempFile.close() + atexit.register(lambda: os.unlink(tempFile.name)) + jrdd = self.textFile(tempFile.name, numSlices)._pipePairs([], "echo") + return PairRDD(jrdd, self, keySerializer, valSerializer) + + def textFile(self, name, numSlices=None, serializer=NopSerializer): + numSlices = numSlices or self.defaultParallelism + jrdd = self._jsc.textFile(name, numSlices) + return RDD(jrdd, self, serializer) diff --git a/pyspark/pyspark/examples/__init__.py b/pyspark/pyspark/examples/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pyspark/pyspark/examples/kmeans.py b/pyspark/pyspark/examples/kmeans.py new file mode 100644 index 0000000000..0761d6e395 --- /dev/null +++ b/pyspark/pyspark/examples/kmeans.py @@ -0,0 +1,56 @@ +import sys + +from pyspark.context import SparkContext + + +def parseVector(line): + return [float(x) for x in line.split(' ')] + + +def addVec(x, y): + return [a + b for (a, b) in zip(x, y)] + + +def squaredDist(x, y): + return sum((a - b) ** 2 for (a, b) in zip(x, y)) + + +def closestPoint(p, centers): + bestIndex = 0 + closest = float("+inf") + for i in range(len(centers)): + tempDist = squaredDist(p, centers[i]) + if tempDist < closest: + closest = tempDist + bestIndex = i + return bestIndex + + +if __name__ == "__main__": + if len(sys.argv) < 5: + print >> sys.stderr, \ + "Usage: PythonKMeans " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + lines = sc.textFile(sys.argv[2]) + data = lines.map(parseVector).cache() + K = int(sys.argv[3]) + convergeDist = float(sys.argv[4]) + + kPoints = data.takeSample(False, K, 34) + tempDist = 1.0 + + while tempDist > convergeDist: + closest = data.mapPairs( + lambda p : (closestPoint(p, kPoints), (p, 1))) + pointStats = closest.reduceByKey( + lambda (x1, y1), (x2, y2): (addVec(x1, x2), y1 + y2)) + newPoints = pointStats.mapPairs( + lambda (x, (y, z)): (x, [a / z for a in y])).collect() + + tempDist = sum(squaredDist(kPoints[x], y) for (x, y) in newPoints) + + for (x, y) in newPoints: + kPoints[x] = y + + print "Final centers: " + str(kPoints) diff --git a/pyspark/pyspark/examples/pi.py b/pyspark/pyspark/examples/pi.py new file mode 100644 index 0000000000..ad77694c41 --- /dev/null +++ b/pyspark/pyspark/examples/pi.py @@ -0,0 +1,20 @@ +import sys +from random import random +from operator import add +from pyspark.context import SparkContext + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonPi []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + n = 100000 * slices + def f(_): + x = random() * 2 - 1 + y = random() * 2 - 1 + return 1 if x ** 2 + y ** 2 < 1 else 0 + count = 
sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) + print "Pi is roughly %f" % (4.0 * count / n) diff --git a/pyspark/pyspark/examples/tc.py b/pyspark/pyspark/examples/tc.py new file mode 100644 index 0000000000..2796fdc6ad --- /dev/null +++ b/pyspark/pyspark/examples/tc.py @@ -0,0 +1,49 @@ +import sys +from random import Random +from pyspark.context import SparkContext + +numEdges = 200 +numVertices = 100 +rand = Random(42) + + +def generateGraph(): + edges = set() + while len(edges) < numEdges: + src = rand.randrange(0, numEdges) + dst = rand.randrange(0, numEdges) + if src != dst: + edges.add((src, dst)) + return edges + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonTC []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + tc = sc.parallelizePairs(generateGraph(), slices).cache() + + # Linear transitive closure: each round grows paths by one edge, + # by joining the graph's edges with the already-discovered paths. + # e.g. join the path (y, z) from the TC with the edge (x, y) from + # the graph to obtain the path (x, z). + + # Because join() joins on keys, the edges are stored in reversed order. + edges = tc.mapPairs(lambda (x, y): (y, x)) + + oldCount = 0L + nextCount = tc.count() + while True: + oldCount = nextCount + # Perform the join, obtaining an RDD of (y, (z, x)) pairs, + # then project the result to obtain the new (x, z) paths. + new_edges = tc.join(edges).mapPairs(lambda (_, (a, b)): (b, a)) + tc = tc.union(new_edges).distinct().cache() + nextCount = tc.count() + if nextCount == oldCount: + break + + print "TC has %i edges" % tc.count() diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py new file mode 100644 index 0000000000..2df80aee85 --- /dev/null +++ b/pyspark/pyspark/java_gateway.py @@ -0,0 +1,20 @@ +import glob +import os +from py4j.java_gateway import java_import, JavaGateway + + +SPARK_HOME = os.environ["SPARK_HOME"] + + +assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \ + "/spark-core-assembly-*-SNAPSHOT.jar")[0] + + +def launch_gateway(): + gateway = JavaGateway.launch_gateway(classpath=assembly_jar, + javaopts=["-Xmx256m"], die_on_exit=True) + java_import(gateway.jvm, "spark.api.java.*") + java_import(gateway.jvm, "spark.api.python.*") + java_import(gateway.jvm, "scala.Tuple2") + java_import(gateway.jvm, "spark.api.python.PythonRDD.pythonDump") + return gateway diff --git a/pyspark/pyspark/join.py b/pyspark/pyspark/join.py new file mode 100644 index 0000000000..c67520fce8 --- /dev/null +++ b/pyspark/pyspark/join.py @@ -0,0 +1,104 @@ +""" +Copyright (c) 2011, Douban Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" +from pyspark.serializers import PairSerializer, OptionSerializer, \ + ArraySerializer + + +def _do_python_join(rdd, other, numSplits, dispatch, valSerializer): + vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) + ws = other.mapPairs(lambda (k, v): (k, (2, v))) + return vs.union(ws).groupByKey(numSplits) \ + .flatMapValues(dispatch, valSerializer) + + +def python_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return [(v, w) for v in vbuf for w in wbuf] + valSerializer = PairSerializer(rdd.valSerializer, other.valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + + +def python_right_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not vbuf: + vbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + valSerializer = PairSerializer(OptionSerializer(rdd.valSerializer), + other.valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + + +def python_left_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not wbuf: + wbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + valSerializer = PairSerializer(rdd.valSerializer, + OptionSerializer(other.valSerializer)) + return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + + +def python_cogroup(rdd, other, numSplits): + resultValSerializer = PairSerializer( + ArraySerializer(rdd.valSerializer), + ArraySerializer(other.valSerializer)) + vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) + ws = other.mapPairs(lambda (k, v): (k, (2, v))) + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return (vbuf, wbuf) + return vs.union(ws).groupByKey(numSplits) \ + .mapValues(dispatch, resultValSerializer) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py new file mode 100644 index 0000000000..c892e86b93 --- /dev/null +++ b/pyspark/pyspark/rdd.py @@ -0,0 +1,517 @@ +from base64 import standard_b64encode as b64enc +from cloud.serialization import cloudpickle +from itertools import chain + +from pyspark.serializers import PairSerializer, NopSerializer, \ + OptionSerializer, ArraySerializer +from pyspark.join import python_join, python_left_outer_join, \ + python_right_outer_join, python_cogroup + + +class RDD(object): + + def __init__(self, jrdd, ctx, serializer=None): + self._jrdd = jrdd + self.is_cached = False + self.ctx = ctx + self.serializer = serializer or ctx.defaultSerializer + + def 
_builder(self, jrdd, ctx): + return RDD(jrdd, ctx, self.serializer) + + @property + def id(self): + return self._jrdd.id() + + @property + def splits(self): + return self._jrdd.splits() + + @classmethod + def _get_pipe_command(cls, command, functions): + if functions and not isinstance(functions, (list, tuple)): + functions = [functions] + worker_args = [command] + for f in functions: + worker_args.append(b64enc(cloudpickle.dumps(f))) + return " ".join(worker_args) + + def cache(self): + self.is_cached = True + self._jrdd.cache() + return self + + def map(self, f, serializer=None, preservesPartitioning=False): + return MappedRDD(self, f, serializer, preservesPartitioning) + + def mapPairs(self, f, keySerializer=None, valSerializer=None, + preservesPartitioning=False): + return PairMappedRDD(self, f, keySerializer, valSerializer, + preservesPartitioning) + + def flatMap(self, f, serializer=None): + """ + >>> rdd = sc.parallelize([2, 3, 4]) + >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect()) + [1, 1, 1, 2, 2, 3] + """ + serializer = serializer or self.ctx.defaultSerializer + dumps = serializer.dumps + loads = self.serializer.loads + def func(x): + pickled_elems = (dumps(y) for y in f(loads(x))) + return "\n".join(pickled_elems) or None + pipe_command = RDD._get_pipe_command("map", [func]) + class_manifest = self._jrdd.classManifest() + jrdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, + False, self.ctx.pythonExec, + class_manifest).asJavaRDD() + return RDD(jrdd, self.ctx, serializer) + + def flatMapPairs(self, f, keySerializer=None, valSerializer=None, + preservesPartitioning=False): + """ + >>> rdd = sc.parallelize([2, 3, 4]) + >>> sorted(rdd.flatMapPairs(lambda x: [(x, x), (x, x)]).collect()) + [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] + """ + keySerializer = keySerializer or self.ctx.defaultSerializer + valSerializer = valSerializer or self.ctx.defaultSerializer + dumpk = keySerializer.dumps + dumpv = valSerializer.dumps + loads = self.serializer.loads + def func(x): + pairs = f(loads(x)) + pickled_pairs = ((dumpk(k), dumpv(v)) for (k, v) in pairs) + return "\n".join(chain.from_iterable(pickled_pairs)) or None + pipe_command = RDD._get_pipe_command("map", [func]) + class_manifest = self._jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, + preservesPartitioning, self.ctx.pythonExec, class_manifest) + return PairRDD(python_rdd.asJavaPairRDD(), self.ctx, keySerializer, + valSerializer) + + def filter(self, f): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4, 5]) + >>> rdd.filter(lambda x: x % 2 == 0).collect() + [2, 4] + """ + loads = self.serializer.loads + def filter_func(x): return x if f(loads(x)) else None + return self._builder(self._pipe(filter_func), self.ctx) + + def _pipe(self, functions, command="map"): + class_manifest = self._jrdd.classManifest() + pipe_command = RDD._get_pipe_command(command, functions) + python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, + False, self.ctx.pythonExec, class_manifest) + return python_rdd.asJavaRDD() + + def _pipePairs(self, functions, command="mapPairs", + preservesPartitioning=False): + class_manifest = self._jrdd.classManifest() + pipe_command = RDD._get_pipe_command(command, functions) + python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, + preservesPartitioning, self.ctx.pythonExec, class_manifest) + return python_rdd.asJavaPairRDD() + + def distinct(self): + """ + >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) + [1, 2, 3] 
+ """ + if self.serializer.is_comparable: + return self._builder(self._jrdd.distinct(), self.ctx) + return self.mapPairs(lambda x: (x, "")) \ + .reduceByKey(lambda x, _: x) \ + .map(lambda (x, _): x) + + def sample(self, withReplacement, fraction, seed): + jrdd = self._jrdd.sample(withReplacement, fraction, seed) + return self._builder(jrdd, self.ctx) + + def takeSample(self, withReplacement, num, seed): + vals = self._jrdd.takeSample(withReplacement, num, seed) + return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + + def union(self, other): + """ + >>> rdd = sc.parallelize([1, 1, 2, 3]) + >>> rdd.union(rdd).collect() + [1, 1, 2, 3, 1, 1, 2, 3] + """ + return self._builder(self._jrdd.union(other._jrdd), self.ctx) + + # TODO: sort + + # TODO: Overload __add___? + + # TODO: glom + + def cartesian(self, other): + """ + >>> rdd = sc.parallelize([1, 2]) + >>> sorted(rdd.cartesian(rdd).collect()) + [(1, 1), (1, 2), (2, 1), (2, 2)] + """ + return PairRDD(self._jrdd.cartesian(other._jrdd), self.ctx) + + # numsplits + def groupBy(self, f, numSplits=None): + """ + >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) + >>> sorted(rdd.groupBy(lambda x: x % 2).collect()) + [(0, [2, 8]), (1, [1, 1, 3, 5])] + """ + return self.mapPairs(lambda x: (f(x), x)).groupByKey(numSplits) + + # TODO: pipe + + # TODO: mapPartitions + + def foreach(self, f): + """ + >>> def f(x): print x + >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) + """ + self.map(f).collect() # Force evaluation + + def collect(self): + vals = self._jrdd.collect() + return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + + def reduce(self, f, serializer=None): + """ + >>> import operator + >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(operator.add) + 15 + """ + serializer = serializer or self.ctx.defaultSerializer + loads = self.serializer.loads + dumps = serializer.dumps + def reduceFunction(x, acc): + if acc is None: + return loads(x) + else: + return f(loads(x), acc) + vals = self._pipe([reduceFunction, dumps], command="reduce").collect() + return reduce(f, (serializer.loads(x) for x in vals)) + + # TODO: fold + + # TODO: aggregate + + def count(self): + """ + >>> sc.parallelize([2, 3, 4]).count() + 3L + """ + return self._jrdd.count() + + # TODO: count approx methods + + def take(self, num): + """ + >>> sc.parallelize([2, 3, 4]).take(2) + [2, 3] + """ + vals = self._jrdd.take(num) + return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + + def first(self): + """ + >>> sc.parallelize([2, 3, 4]).first() + 2 + """ + return self.serializer.loads(self.ctx.python_dump(self._jrdd.first())) + + # TODO: saveAsTextFile + + # TODO: saveAsObjectFile + + +class PairRDD(RDD): + + def __init__(self, jrdd, ctx, keySerializer=None, valSerializer=None): + RDD.__init__(self, jrdd, ctx) + self.keySerializer = keySerializer or ctx.defaultSerializer + self.valSerializer = valSerializer or ctx.defaultSerializer + self.serializer = \ + PairSerializer(self.keySerializer, self.valSerializer) + + def _builder(self, jrdd, ctx): + return PairRDD(jrdd, ctx, self.keySerializer, self.valSerializer) + + def reduceByKey(self, func, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(x.reduceByKey(lambda a, b: a + b).collect()) + [('a', 2), ('b', 1)] + """ + return self.combineByKey(lambda x: x, func, func, numSplits) + + # TODO: reduceByKeyLocally() + + # TODO: countByKey() + + # TODO: partitionBy + + def join(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), 
("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2), ("a", 3)]) + >>> x.join(y).collect() + [('a', (1, 2)), ('a', (1, 3))] + + Check that we get a PairRDD-like object back: + >>> assert x.join(y).join + """ + assert self.keySerializer.name == other.keySerializer.name + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.join(other._jrdd), + self.ctx, self.keySerializer, + PairSerializer(self.valSerializer, other.valSerializer)) + else: + return python_join(self, other, numSplits) + + def leftOuterJoin(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2)]) + >>> sorted(x.leftOuterJoin(y).collect()) + [('a', (1, 2)), ('b', (4, None))] + """ + assert self.keySerializer.name == other.keySerializer.name + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.leftOuterJoin(other._jrdd), + self.ctx, self.keySerializer, + PairSerializer(self.valSerializer, + OptionSerializer(other.valSerializer))) + else: + return python_left_outer_join(self, other, numSplits) + + def rightOuterJoin(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2)]) + >>> sorted(y.rightOuterJoin(x).collect()) + [('a', (2, 1)), ('b', (None, 4))] + """ + assert self.keySerializer.name == other.keySerializer.name + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.rightOuterJoin(other._jrdd), + self.ctx, self.keySerializer, + PairSerializer(OptionSerializer(self.valSerializer), + other.valSerializer)) + else: + return python_right_outer_join(self, other, numSplits) + + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, + numSplits=None, serializer=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> def f(x): return x + >>> def add(a, b): return a + str(b) + >>> sorted(x.combineByKey(str, add, add).collect()) + [('a', '11'), ('b', '1')] + """ + serializer = serializer or self.ctx.defaultSerializer + if numSplits is None: + numSplits = self.ctx.defaultParallelism + # Use hash() to create keys that are comparable in Java. + loadkv = self.serializer.loads + def pairify(kv): + # TODO: add method to deserialize only the key or value from + # a PairSerializer? 
+ key = loadkv(kv)[0] + return (str(hash(key)), kv) + partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) + jrdd = self._pipePairs(pairify).partitionBy(partitioner) + pairified = PairRDD(jrdd, self.ctx, NopSerializer, self.serializer) + + loads = PairSerializer(NopSerializer, self.serializer).loads + dumpk = self.keySerializer.dumps + dumpc = serializer.dumps + + functions = [createCombiner, mergeValue, mergeCombiners, loads, dumpk, + dumpc] + jpairs = pairified._pipePairs(functions, "combine_by_key", + preservesPartitioning=True) + return PairRDD(jpairs, self.ctx, self.keySerializer, serializer) + + def groupByKey(self, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(x.groupByKey().collect()) + [('a', [1, 1]), ('b', [1])] + """ + + def createCombiner(x): + return [x] + + def mergeValue(xs, x): + xs.append(x) + return xs + + def mergeCombiners(a, b): + return a + b + + return self.combineByKey(createCombiner, mergeValue, mergeCombiners, + numSplits) + + def collectAsMap(self): + """ + >>> m = sc.parallelizePairs([(1, 2), (3, 4)]).collectAsMap() + >>> m[1] + 2 + >>> m[3] + 4 + """ + m = self._jrdd.collectAsMap() + def loads(x): + (k, v) = x + return (self.keySerializer.loads(k), self.valSerializer.loads(v)) + return dict(loads(x) for x in m.items()) + + def flatMapValues(self, f, valSerializer=None): + flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + return self.flatMapPairs(flat_map_fn, self.keySerializer, + valSerializer, True) + + def mapValues(self, f, valSerializer=None): + map_values_fn = lambda (k, v): (k, f(v)) + return self.mapPairs(map_values_fn, self.keySerializer, valSerializer, + True) + + # TODO: support varargs cogroup of several RDDs. + def groupWith(self, other): + return self.cogroup(other) + + def cogroup(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2)]) + >>> x.cogroup(y).collect() + [('a', ([1], [2])), ('b', ([4], []))] + """ + assert self.keySerializer.name == other.keySerializer.name + resultValSerializer = PairSerializer( + ArraySerializer(self.valSerializer), + ArraySerializer(other.valSerializer)) + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.cogroup(other._jrdd), + self.ctx, self.keySerializer, resultValSerializer) + else: + return python_cogroup(self, other, numSplits) + + # TODO: `lookup` is disabled because we can't make direct comparisons based + # on the key; we need to compare the hash of the key to the hash of the + # keys in the pairs. This could be an expensive operation, since those + # hashes aren't retained. 
+ + # TODO: file saving + + +class MappedRDDBase(object): + def __init__(self, prev, func, serializer, preservesPartitioning=False): + if isinstance(prev, MappedRDDBase) and not prev.is_cached: + prev_func = prev.func + self.func = lambda x: func(prev_func(x)) + self.preservesPartitioning = \ + prev.preservesPartitioning and preservesPartitioning + self._prev_jrdd = prev._prev_jrdd + self._prev_serializer = prev._prev_serializer + else: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jrdd = prev._jrdd + self._prev_serializer = prev.serializer + self.serializer = serializer or prev.ctx.defaultSerializer + self.is_cached = False + self.ctx = prev.ctx + self.prev = prev + self._jrdd_val = None + + +class MappedRDD(MappedRDDBase, RDD): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4]) + >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + """ + + @property + def _jrdd(self): + if not self._jrdd_val: + udf = self.func + loads = self._prev_serializer.loads + dumps = self.serializer.dumps + func = lambda x: dumps(udf(loads(x))) + pipe_command = RDD._get_pipe_command("map", [func]) + class_manifest = self._prev_jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), + pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + class_manifest) + self._jrdd_val = python_rdd.asJavaRDD() + return self._jrdd_val + + +class PairMappedRDD(MappedRDDBase, PairRDD): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4]) + >>> rdd.mapPairs(lambda x: (x, x)) \\ + ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ + ... .collect() + [(2, 2), (4, 4), (6, 6), (8, 8)] + >>> rdd.mapPairs(lambda x: (x, x)) \\ + ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ + ... .map(lambda (x, _): x).collect() + [2, 4, 6, 8] + """ + + def __init__(self, prev, func, keySerializer=None, valSerializer=None, + preservesPartitioning=False): + self.keySerializer = keySerializer or prev.ctx.defaultSerializer + self.valSerializer = valSerializer or prev.ctx.defaultSerializer + serializer = PairSerializer(self.keySerializer, self.valSerializer) + MappedRDDBase.__init__(self, prev, func, serializer, + preservesPartitioning) + + @property + def _jrdd(self): + if not self._jrdd_val: + udf = self.func + loads = self._prev_serializer.loads + dumpk = self.keySerializer.dumps + dumpv = self.valSerializer.dumps + def func(x): + (k, v) = udf(loads(x)) + return (dumpk(k), dumpv(v)) + pipe_command = RDD._get_pipe_command("mapPairs", [func]) + class_manifest = self._prev_jrdd.classManifest() + self._jrdd_val = self.ctx.jvm.PythonPairRDD(self._prev_jrdd.rdd(), + pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + class_manifest).asJavaPairRDD() + return self._jrdd_val + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.serializers import PickleSerializer, JSONSerializer + globs = globals().copy() + globs['sc'] = SparkContext('local', 'PythonTest', + defaultSerializer=JSONSerializer) + doctest.testmod(globs=globs) + globs['sc'].stop() + globs['sc'] = SparkContext('local', 'PythonTest', + defaultSerializer=PickleSerializer) + doctest.testmod(globs=globs) + globs['sc'].stop() + + +if __name__ == "__main__": + _test() diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py new file mode 100644 index 0000000000..b113f5656b --- /dev/null +++ b/pyspark/pyspark/serializers.py @@ -0,0 +1,229 @@ +""" +Data serialization methods. 
+ +The Spark Python API is built on top of the Spark Java API. RDDs created in +Python are stored in Java as RDDs of Strings. Python objects are automatically +serialized/deserialized, so this representation is transparent to the end-user. + +------------------ +Serializer objects +------------------ + +`Serializer` objects are used to customize how an RDD's values are serialized. + +Each `Serializer` is a named tuple with four fields: + + - A `dumps` function, for serializing a Python object to a string. + + - A `loads` function, for deserializing a Python object from a string. + + - An `is_comparable` field, True if equal Python objects are serialized to + equal strings, and False otherwise. + + - A `name` field, used to identify the Serializer. Serializers are + compared for equality by comparing their names. + +The serializer's output should be base64-encoded. + +------------------------------------------------------------------ +`is_comparable`: comparing serialized representations for equality +------------------------------------------------------------------ + +If `is_comparable` is False, the serializer's representations of equal objects +are not required to be equal: + +>>> import pickle +>>> a = {1: 0, 9: 0} +>>> b = {9: 0, 1: 0} +>>> a == b +True +>>> pickle.dumps(a) == pickle.dumps(b) +False + +RDDs with comparable serializers can use native Java implementations of +operations like join() and distinct(), which may lead to better performance by +eliminating deserialization and Python comparisons. + +The default JSONSerializer produces comparable representations of common Python +data structures. + +-------------------------------------- +Examples of serialized representations +-------------------------------------- + +The RDD transformations that use Python UDFs are implemented in terms of +a modified `PipedRDD.pipe()` function. For each record `x` in the RDD, the +`pipe()` function pipes `x.toString()` to a Python worker process, which +deserializes the string into a Python object, executes user-defined functions, +and outputs serialized Python objects. + +The regular `toString()` method returns an ambiguous representation, due to the +way that Scala `Option` instances are printed: + +>>> from context import SparkContext +>>> sc = SparkContext("local", "SerializerDocs") +>>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) +>>> y = sc.parallelizePairs([("a", 2)]) + +>>> print y.rightOuterJoin(x)._jrdd.first().toString() +(ImEi,(Some(Mg==),MQ==)) + +In Java, preprocessing is performed to handle Option instances, so the Python +process receives unambiguous input: + +>>> print sc.python_dump(y.rightOuterJoin(x)._jrdd.first()) +(ImEi,(Mg==,MQ==)) + +The base64-encoding eliminates the need to escape newlines, parentheses and +other special characters. + +---------------------- +Serializer composition +---------------------- + +In order to handle nested structures, which could contain object serialized +with different serializers, the RDD module composes serializers. 
For example, +the serializers in the previous example are: + +>>> print x.serializer.name +PairSerializer + +>>> print y.serializer.name +PairSerializer + +>>> print y.rightOuterJoin(x).serializer.name +PairSerializer, JSONSerializer>> +""" +from base64 import standard_b64encode, standard_b64decode +from collections import namedtuple +import cPickle +import simplejson + + +Serializer = namedtuple("Serializer", + ["dumps","loads", "is_comparable", "name"]) + + +NopSerializer = Serializer(str, str, True, "NopSerializer") + + +JSONSerializer = Serializer( + lambda obj: standard_b64encode(simplejson.dumps(obj, sort_keys=True, + separators=(',', ':'))), + lambda s: simplejson.loads(standard_b64decode(s)), + True, + "JSONSerializer" +) + + +PickleSerializer = Serializer( + lambda obj: standard_b64encode(cPickle.dumps(obj)), + lambda s: cPickle.loads(standard_b64decode(s)), + False, + "PickleSerializer" +) + + +def OptionSerializer(serializer): + """ + >>> ser = OptionSerializer(NopSerializer) + >>> ser.loads(ser.dumps("Hello, World!")) + 'Hello, World!' + >>> ser.loads(ser.dumps(None)) is None + True + """ + none_placeholder = '*' + + def dumps(x): + if x is None: + return none_placeholder + else: + return serializer.dumps(x) + + def loads(x): + if x == none_placeholder: + return None + else: + return serializer.loads(x) + + name = "OptionSerializer<%s>" % serializer.name + return Serializer(dumps, loads, serializer.is_comparable, name) + + +def PairSerializer(keySerializer, valSerializer): + """ + Returns a Serializer for a (key, value) pair. + + >>> ser = PairSerializer(JSONSerializer, JSONSerializer) + >>> ser.loads(ser.dumps((1, 2))) + (1, 2) + + >>> ser = PairSerializer(JSONSerializer, ser) + >>> ser.loads(ser.dumps((1, (2, 3)))) + (1, (2, 3)) + """ + def loads(kv): + try: + (key, val) = kv[1:-1].split(',', 1) + key = keySerializer.loads(key) + val = valSerializer.loads(val) + return (key, val) + except: + print "Error in deserializing pair from '%s'" % str(kv) + raise + + def dumps(kv): + (key, val) = kv + return"(%s,%s)" % (keySerializer.dumps(key), valSerializer.dumps(val)) + is_comparable = \ + keySerializer.is_comparable and valSerializer.is_comparable + name = "PairSerializer<%s, %s>" % (keySerializer.name, valSerializer.name) + return Serializer(dumps, loads, is_comparable, name) + + +def ArraySerializer(serializer): + """ + >>> ser = ArraySerializer(JSONSerializer) + >>> ser.loads(ser.dumps([1, 2, 3, 4])) + [1, 2, 3, 4] + >>> ser = ArraySerializer(PairSerializer(JSONSerializer, PickleSerializer)) + >>> ser.loads(ser.dumps([('a', 1), ('b', 2)])) + [('a', 1), ('b', 2)] + >>> ser.loads(ser.dumps([('a', 1)])) + [('a', 1)] + >>> ser.loads(ser.dumps([])) + [] + """ + def dumps(arr): + if arr == []: + return '[]' + else: + return '[' + '|'.join(serializer.dumps(x) for x in arr) + ']' + + def loads(s): + if s == '[]': + return [] + items = s[1:-1] + if '|' in items: + items = items.split('|') + else: + items = [items] + return [serializer.loads(x) for x in items] + + name = "ArraySerializer<%s>" % serializer.name + return Serializer(dumps, loads, serializer.is_comparable, name) + + +# TODO: IntegerSerializer + + +# TODO: DoubleSerializer + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py new file mode 100644 index 0000000000..4d4cc939c3 --- /dev/null +++ b/pyspark/pyspark/worker.py @@ -0,0 +1,97 @@ +""" +Worker that receives input from Piped RDD. 
+"""
+import sys
+from base64 import standard_b64decode
+# CloudPickler needs to be imported so that depicklers are registered using the
+# copy_reg module.
+from cloud.serialization.cloudpickle import CloudPickler
+import cPickle
+
+
+# Redirect stdout to stderr so that users must return values from functions.
+old_stdout = sys.stdout
+sys.stdout = sys.stderr
+
+
+def load_function():
+    return cPickle.loads(standard_b64decode(sys.stdin.readline().strip()))
+
+
+def output(x):
+    for line in x.split("\n"):
+        old_stdout.write(line.rstrip("\r\n") + "\n")
+
+
+def read_input():
+    for line in sys.stdin:
+        yield line.rstrip("\r\n")
+
+
+def do_combine_by_key():
+    create_combiner = load_function()
+    merge_value = load_function()
+    merge_combiners = load_function()  # TODO: not used.
+    depickler = load_function()
+    key_pickler = load_function()
+    combiner_pickler = load_function()
+    combiners = {}
+    for line in read_input():
+        # Discard the hashcode added in the Python combineByKey() method.
+        (key, value) = depickler(line)[1]
+        if key not in combiners:
+            combiners[key] = create_combiner(value)
+        else:
+            combiners[key] = merge_value(combiners[key], value)
+    for (key, combiner) in combiners.iteritems():
+        output(key_pickler(key))
+        output(combiner_pickler(combiner))
+
+
+def do_map(map_pairs=False):
+    f = load_function()
+    for line in read_input():
+        try:
+            out = f(line)
+            if out is not None:
+                if map_pairs:
+                    for x in out:
+                        output(x)
+                else:
+                    output(out)
+        except:
+            sys.stderr.write("Error processing line '%s'\n" % line)
+            raise
+
+
+def do_reduce():
+    f = load_function()
+    dumps = load_function()
+    acc = None
+    for line in read_input():
+        acc = f(line, acc)
+    output(dumps(acc))
+
+
+def do_echo():
+    old_stdout.writelines(sys.stdin.readlines())
+
+
+def main():
+    command = sys.stdin.readline().strip()
+    if command == "map":
+        do_map(map_pairs=False)
+    elif command == "mapPairs":
+        do_map(map_pairs=True)
+    elif command == "combine_by_key":
+        do_combine_by_key()
+    elif command == "reduce":
+        do_reduce()
+    elif command == "echo":
+        do_echo()
+    else:
+        raise Exception("Unsupported command %s" % command)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/pyspark/requirements.txt b/pyspark/requirements.txt
new file mode 100644
index 0000000000..d9b3fe40bd
--- /dev/null
+++ b/pyspark/requirements.txt
@@ -0,0 +1,9 @@
+# The Python API relies on some new features from the Py4J development branch.
+# pip can't install Py4J from git because the setup.py file for the Python
+# package is not at the root of the git repository.  It may be possible to
+# install Py4J from git once https://github.com/pypa/pip/pull/526 is merged.
+
+# git+git://github.com/bartdag/py4j.git@3dbf380d3d2cdeb9aab394454ea74d80c4aba1ea
+
+simplejson==2.6.1
+cloud==2.5.5
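
For orientation, a minimal usage sketch of the API introduced by this patch, modeled on the doctests in rdd.py; it is illustrative rather than part of the commit. It assumes Python 2 with the packages from requirements.txt installed, SPARK_HOME pointing at a checkout whose spark-core assembly jar has been built (see java_gateway.py), and the pyspark package on the Python path; the application name and sample data are arbitrary.

from pyspark.context import SparkContext

# Word count over an in-memory collection using the Python RDD API above:
# parallelize() ships the data through a temp file, mapPairs() produces a
# PairRDD, and reduceByKey() runs the Python combine-by-key worker.
sc = SparkContext("local", "PythonWordCount")
lines = sc.parallelize(["a b a", "b c", "a"])
counts = lines.flatMap(lambda line: line.split(" ")) \
              .mapPairs(lambda word: (word, 1)) \
              .reduceByKey(lambda a, b: a + b)
print counts.collect()  # e.g. [('a', 3), ('b', 2), ('c', 1)]
sc.stop()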