author     Josh Rosen <joshrosen@eecs.berkeley.edu>    2012-08-25 16:46:07 -0700
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>    2012-08-27 00:24:39 -0700
commit     200d248dcc5903295296bf897211cf543b37f8c1 (patch)
tree       46df15fbccf99489a1f7f240c71cc56ef083d6d8 /pyspark
parent     6904cb77d4306a14891cc71338c8f9f966d009f1 (diff)
Simplify Python worker; pipeline the map step of partitionBy().
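In short: the worker no longer dispatches on a command string ("pipeline" vs. "shuffle_map_step"). It always receives one pickled function plus a bypass-serializer flag, and partitionBy() expresses its key-hashing map step as an ordinary pipelined function, so that step fuses with upstream transformations into a single pass over each partition. A minimal sketch of the fusion idea (the helper names below are illustrative only, not part of the patch):

    # Illustrative sketch, not patch code: how composing partition-level
    # functions collapses two worker passes into one.
    from itertools import ifilter, imap

    def compose(f, g):
        # Feed f's output into g; both take and return iterators.
        return lambda iterator: g(f(iterator))

    def double(iterator):
        return imap(lambda x: x * 2, iterator)

    def keep_big(iterator):
        return ifilter(lambda x: x > 4, iterator)

    fused = compose(double, keep_big)    # one pass over the partition, not two
    print list(fused(iter([1, 2, 3])))   # prints [6]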
Diffstat (limited to 'pyspark')
-rw-r--r--  pyspark/pyspark/context.py     |  9
-rw-r--r--  pyspark/pyspark/rdd.py         | 70
-rw-r--r--  pyspark/pyspark/serializers.py | 23
-rw-r--r--  pyspark/pyspark/worker.py      | 50
4 files changed, 52 insertions(+), 100 deletions(-)
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 6f87206665..b8490019e3 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -4,7 +4,7 @@ from tempfile import NamedTemporaryFile
 
 from pyspark.broadcast import Broadcast
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import PickleSerializer, dumps
+from pyspark.serializers import dump_pickle, write_with_length
 from pyspark.rdd import RDD
 
@@ -16,9 +16,8 @@ class SparkContext(object):
     asPickle = jvm.spark.api.python.PythonRDD.asPickle
     arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle
 
-    def __init__(self, master, name, defaultParallelism=None,
-        pythonExec='python'):
+    def __init__(self, master, name, defaultParallelism=None, pythonExec='python'):
         self.master = master
         self.name = name
         self._jsc = self.jvm.JavaSparkContext(master, name)
@@ -52,7 +51,7 @@ class SparkContext(object):
         # objects are written to a file and loaded through textFile().
         tempFile = NamedTemporaryFile(delete=False)
         for x in c:
-            dumps(PickleSerializer.dumps(x), tempFile)
+            write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
         atexit.register(lambda: os.unlink(tempFile.name))
         jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
@@ -64,6 +63,6 @@ class SparkContext(object):
         return RDD(jrdd, self)
 
     def broadcast(self, value):
-        jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value)))
+        jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
         return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
                          self._pickled_broadcast_vars)
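For reference, the parallelize() path above now writes each element pickled and framed with a 4-byte big-endian length. A self-contained sketch of that file format, reusing the same dump_pickle/write_with_length definitions this patch adds to serializers.py (the sample data is arbitrary):

    import struct
    import cPickle
    from tempfile import NamedTemporaryFile

    def dump_pickle(obj):
        # Pickle protocol 2, as in the new serializers.py.
        return cPickle.dumps(obj, 2)

    def write_with_length(obj, stream):
        # Frame the payload with a 4-byte big-endian length prefix.
        stream.write(struct.pack("!i", len(obj)))
        stream.write(obj)

    tempFile = NamedTemporaryFile(delete=False)
    for x in [1, "two", (3, 4.0)]:
        write_with_length(dump_pickle(x), tempFile)
    tempFile.close()  # the JVM-side pickleFile() reader consumes this file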
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 3528b8f308..21e822ba9f 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -3,7 +3,7 @@ from collections import Counter
 from itertools import chain, ifilter, imap
 
 from pyspark import cloudpickle
-from pyspark.serializers import PickleSerializer
+from pyspark.serializers import dump_pickle, load_pickle
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 
@@ -17,17 +17,6 @@ class RDD(object):
         self.is_cached = False
         self.ctx = ctx
 
-    @classmethod
-    def _get_pipe_command(cls, ctx, command, functions):
-        worker_args = [command]
-        for f in functions:
-            worker_args.append(b64enc(cloudpickle.dumps(f)))
-        broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars]
-        broadcast_vars = ListConverter().convert(broadcast_vars,
-            ctx.gateway._gateway_client)
-        ctx._pickled_broadcast_vars.clear()
-        return (" ".join(worker_args), broadcast_vars)
-
     def cache(self):
         self.is_cached = True
         self._jrdd.cache()
@@ -66,14 +55,6 @@ class RDD(object):
         def func(iterator): return ifilter(f, iterator)
         return self.mapPartitions(func)
 
-    def _pipe(self, functions, command):
-        class_manifest = self._jrdd.classManifest()
-        (pipe_command, broadcast_vars) = \
-            RDD._get_pipe_command(self.ctx, command, functions)
-        python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
-            False, self.ctx.pythonExec, broadcast_vars, class_manifest)
-        return python_rdd.asJavaRDD()
-
     def distinct(self):
         """
         >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
@@ -89,7 +70,7 @@ class RDD(object):
 
     def takeSample(self, withReplacement, num, seed):
         vals = self._jrdd.takeSample(withReplacement, num, seed)
-        return [PickleSerializer.loads(bytes(x)) for x in vals]
+        return [load_pickle(bytes(x)) for x in vals]
 
     def union(self, other):
         """
@@ -148,7 +129,7 @@ class RDD(object):
 
     def collect(self):
         pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect())
-        return PickleSerializer.loads(bytes(pickle))
+        return load_pickle(bytes(pickle))
 
     def reduce(self, f):
         """
@@ -216,19 +197,17 @@ class RDD(object):
         [2, 3]
         """
         pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num))
-        return PickleSerializer.loads(bytes(pickle))
+        return load_pickle(bytes(pickle))
 
     def first(self):
         """
         >>> sc.parallelize([2, 3, 4]).first()
         2
         """
-        return PickleSerializer.loads(bytes(self.ctx.asPickle(self._jrdd.first())))
+        return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first())))
 
     # TODO: saveAsTextFile
 
-    # TODO: saveAsObjectFile
-
     # Pair functions
 
     def collectAsMap(self):
@@ -303,19 +282,18 @@ class RDD(object):
         """
         return python_right_outer_join(self, other, numSplits)
 
-    # TODO: pipelining
-    # TODO: optimizations
     def partitionBy(self, numSplits, hashFunc=hash):
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
-        (pipe_command, broadcast_vars) = \
-            RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc])
-        class_manifest = self._jrdd.classManifest()
-        python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(),
-            pipe_command, False, self.ctx.pythonExec, broadcast_vars,
-            class_manifest)
+        def add_shuffle_key(iterator):
+            for (k, v) in iterator:
+                yield str(hashFunc(k))
+                yield dump_pickle((k, v))
+        keyed = PipelinedRDD(self, add_shuffle_key)
+        keyed._bypass_serializer = True
+        pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
         partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
-        jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner)
+        jrdd = pairRDD.partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
         return RDD(jrdd, self.ctx)
 
@@ -430,17 +408,23 @@ class PipelinedRDD(RDD):
         self.ctx = prev.ctx
         self.prev = prev
         self._jrdd_val = None
+        self._bypass_serializer = False
 
     @property
     def _jrdd(self):
-        if not self._jrdd_val:
-            (pipe_command, broadcast_vars) = \
-                RDD._get_pipe_command(self.ctx, "pipeline", [self.func])
-            class_manifest = self._prev_jrdd.classManifest()
-            python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
-                pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
-                broadcast_vars, class_manifest)
-            self._jrdd_val = python_rdd.asJavaRDD()
+        if self._jrdd_val:
+            return self._jrdd_val
+        funcs = [self.func, self._bypass_serializer]
+        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in funcs)
+        broadcast_vars = ListConverter().convert(
+            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
+            self.ctx.gateway._gateway_client)
+        self.ctx._pickled_broadcast_vars.clear()
+        class_manifest = self._prev_jrdd.classManifest()
+        python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
+            pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
+            broadcast_vars, class_manifest)
+        self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
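The new add_shuffle_key() flattens each (k, v) pair into two consecutive stream records: the hash bucket rendered as a string, then the pickled pair. On the JVM side, PairwiseRDD re-pairs them before partitioning; pair_up() below is a hypothetical Python stand-in for that Scala step, included only to show the framing:

    import cPickle

    def add_shuffle_key(iterator, hashFunc=hash):
        # Worker side: emit the hash-bucket string, then the pickled pair.
        for (k, v) in iterator:
            yield str(hashFunc(k))
            yield cPickle.dumps((k, v), 2)

    def pair_up(records):
        # Hypothetical stand-in for the JVM's PairwiseRDD: consume the
        # worker's output stream two records at a time.
        records = iter(records)
        for shuffle_key in records:
            yield (shuffle_key, next(records))

    for shuffle_key, pickled in pair_up(add_shuffle_key([('a', 1), ('b', 2)])):
        print shuffle_key, cPickle.loads(pickled)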
-""" -from collections import namedtuple -import cPickle import struct +import cPickle -Serializer = namedtuple("Serializer", ["dumps","loads"]) +def dump_pickle(obj): + return cPickle.dumps(obj, 2) -PickleSerializer = Serializer( - lambda obj: cPickle.dumps(obj, -1), - cPickle.loads) +load_pickle = cPickle.loads -def dumps(obj, stream): - # TODO: determining the length of non-byte objects. +def write_with_length(obj, stream): stream.write(struct.pack("!i", len(obj))) stream.write(obj) -def loads(stream): +def read_with_length(stream): length = stream.read(4) if length == "": raise EOFError diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 0f90c6ff46..a9ed71892f 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -7,61 +7,41 @@ from base64 import standard_b64decode # copy_reg module. from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import dumps, loads, PickleSerializer -import cPickle +from pyspark.serializers import write_with_length, read_with_length, \ + dump_pickle, load_pickle + # Redirect stdout to stderr so that users must return values from functions. old_stdout = sys.stdout sys.stdout = sys.stderr -def load_function(): - return cPickle.loads(standard_b64decode(sys.stdin.readline().strip())) - - -def output(x): - dumps(x, old_stdout) +def load_obj(): + return load_pickle(standard_b64decode(sys.stdin.readline().strip())) def read_input(): try: while True: - yield cPickle.loads(loads(sys.stdin)) + yield load_pickle(read_with_length(sys.stdin)) except EOFError: return -def do_pipeline(): - f = load_function() - for obj in f(read_input()): - output(PickleSerializer.dumps(obj)) - - -def do_shuffle_map_step(): - hashFunc = load_function() - while True: - try: - pickled = loads(sys.stdin) - except EOFError: - return - key = cPickle.loads(pickled)[0] - output(str(hashFunc(key))) - output(pickled) - - def main(): num_broadcast_variables = int(sys.stdin.readline().strip()) for _ in range(num_broadcast_variables): uuid = sys.stdin.read(36) - value = loads(sys.stdin) - _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value)) - command = sys.stdin.readline().strip() - if command == "pipeline": - do_pipeline() - elif command == "shuffle_map_step": - do_shuffle_map_step() + value = read_with_length(sys.stdin) + _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value)) + func = load_obj() + bypassSerializer = load_obj() + if bypassSerializer: + dumps = lambda x: x else: - raise Exception("Unsupported command %s" % command) + dumps = dump_pickle + for obj in func(read_input()): + write_with_length(dumps(obj), old_stdout) if __name__ == '__main__': |