From f3b852ce66d193e3421eeecef71ea27bff73a94b Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 24 Aug 2012 19:38:50 -0700
Subject: Refactor Python MappedRDD to use iterator pipelines.

---
 pyspark/pyspark/rdd.py    | 83 +++++++++++++++--------------------------------
 pyspark/pyspark/worker.py | 55 +++++++++----------------------
 2 files changed, 41 insertions(+), 97 deletions(-)

(limited to 'pyspark')

diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index ff9c483032..7d280d8844 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -1,4 +1,5 @@
 from base64 import standard_b64encode as b64enc
+from itertools import chain, ifilter, imap
 
 from pyspark import cloudpickle
 from pyspark.serializers import PickleSerializer
@@ -15,8 +16,6 @@ class RDD(object):
 
     @classmethod
     def _get_pipe_command(cls, command, functions):
-        if functions and not isinstance(functions, (list, tuple)):
-            functions = [functions]
         worker_args = [command]
         for f in functions:
             worker_args.append(b64enc(cloudpickle.dumps(f)))
@@ -28,7 +27,8 @@ class RDD(object):
         return self
 
     def map(self, f, preservesPartitioning=False):
-        return MappedRDD(self, f, preservesPartitioning)
+        def func(iterator): return imap(f, iterator)
+        return PipelinedRDD(self, func, preservesPartitioning)
 
     def flatMap(self, f):
         """
@@ -38,7 +38,8 @@ class RDD(object):
         >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
         [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
         """
-        return MappedRDD(self, f, preservesPartitioning=False, command='flatmap')
+        def func(iterator): return chain.from_iterable(imap(f, iterator))
+        return PipelinedRDD(self, func)
 
     def filter(self, f):
         """
@@ -46,10 +47,10 @@ class RDD(object):
         >>> rdd.filter(lambda x: x % 2 == 0).collect()
         [2, 4]
         """
-        def filter_func(x): return x if f(x) else None
-        return RDD(self._pipe(filter_func), self.ctx)
+        def func(iterator): return ifilter(f, iterator)
+        return PipelinedRDD(self, func)
 
-    def _pipe(self, functions, command="map"):
+    def _pipe(self, functions, command):
         class_manifest = self._jrdd.classManifest()
         pipe_command = RDD._get_pipe_command(command, functions)
         python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
@@ -128,7 +129,16 @@ class RDD(object):
         >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
         10
         """
-        vals = MappedRDD(self, f, command="reduce", preservesPartitioning=False).collect()
+        def func(iterator):
+            acc = None
+            for obj in iterator:
+                if acc is None:
+                    acc = obj
+                else:
+                    acc = f(obj, acc)
+            if acc is not None:
+                yield acc
+        vals = PipelinedRDD(self, func).collect()
         return reduce(f, vals)
 
     # TODO: fold
@@ -230,8 +240,6 @@ class RDD(object):
             jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
         return RDD(jrdd, self.ctx)
 
-
-
     def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
                      numSplits=None):
         """
@@ -297,7 +305,7 @@ class RDD(object):
     # TODO: file saving
 
 
-class MappedRDD(RDD):
+class PipelinedRDD(RDD):
     """
     Pipelined maps:
     >>> rdd = sc.parallelize([1, 2, 3, 4])
@@ -313,68 +321,29 @@ class MappedRDD(RDD):
     >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
     20
     """
-    def __init__(self, prev, func, preservesPartitioning=False, command='map'):
-        if isinstance(prev, MappedRDD) and not prev.is_cached:
+    def __init__(self, prev, func, preservesPartitioning=False):
+        if isinstance(prev, PipelinedRDD) and not prev.is_cached:
             prev_func = prev.func
-            if command == 'reduce':
-                if prev.command == 'flatmap':
-                    def flatmap_reduce_func(x, acc):
-                        values = prev_func(x)
-                        if values is None:
-                            return acc
-                        if not acc:
-                            if len(values) == 1:
-                                return values[0]
-                            else:
-                                return reduce(func, values[1:], values[0])
-                        else:
-                            return reduce(func, values, acc)
-                    self.func = flatmap_reduce_func
-                else:
-                    def reduce_func(x, acc):
-                        val = prev_func(x)
-                        if not val:
-                            return acc
-                        if acc is None:
-                            return val
-                        else:
-                            return func(val, acc)
-                    self.func = reduce_func
-            else:
-                if prev.command == 'flatmap':
-                    command = 'flatmap'
-                    self.func = lambda x: (func(y) for y in prev_func(x))
-                else:
-                    self.func = lambda x: func(prev_func(x))
-
+            def pipeline_func(iterator):
+                return func(prev_func(iterator))
+            self.func = pipeline_func
             self.preservesPartitioning = \
                 prev.preservesPartitioning and preservesPartitioning
             self._prev_jrdd = prev._prev_jrdd
-            self.is_pipelined = True
         else:
-            if command == 'reduce':
-                def reduce_func(val, acc):
-                    if acc is None:
-                        return val
-                    else:
-                        return func(val, acc)
-                self.func = reduce_func
-            else:
-                self.func = func
+            self.func = func
             self.preservesPartitioning = preservesPartitioning
             self._prev_jrdd = prev._jrdd
-            self.is_pipelined = False
         self.is_cached = False
         self.ctx = prev.ctx
         self.prev = prev
         self._jrdd_val = None
-        self.command = command
 
     @property
     def _jrdd(self):
         if not self._jrdd_val:
             funcs = [self.func]
-            pipe_command = RDD._get_pipe_command(self.command, funcs)
+            pipe_command = RDD._get_pipe_command("pipeline", funcs)
             class_manifest = self._prev_jrdd.classManifest()
             python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
                 pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py
index b13ed5699a..76b09918e7 100644
--- a/pyspark/pyspark/worker.py
+++ b/pyspark/pyspark/worker.py
@@ -25,17 +25,17 @@ def output(x):
 def read_input():
     try:
         while True:
-            yield loads(sys.stdin)
+            yield cPickle.loads(loads(sys.stdin))
     except EOFError:
         return
 
+
 def do_combine_by_key():
     create_combiner = load_function()
     merge_value = load_function()
     merge_combiners = load_function()  # TODO: not used.
     combiners = {}
-    for obj in read_input():
-        (key, value) = PickleSerializer.loads(obj)
+    for (key, value) in read_input():
         if key not in combiners:
             combiners[key] = create_combiner(value)
         else:
@@ -44,57 +44,32 @@ def do_combine_by_key():
         output(PickleSerializer.dumps((key, combiner)))
 
 
-def do_map(flat=False):
+def do_pipeline():
     f = load_function()
-    for obj in read_input():
-        try:
-            out = f(PickleSerializer.loads(obj))
-            if out is not None:
-                if flat:
-                    for x in out:
-                        output(PickleSerializer.dumps(x))
-                else:
-                    output(PickleSerializer.dumps(out))
-        except:
-            sys.stderr.write("Error processing obj %s\n" % repr(obj))
-            raise
+    for obj in f(read_input()):
+        output(PickleSerializer.dumps(obj))
 
 
 def do_shuffle_map_step():
     hashFunc = load_function()
-    for obj in read_input():
-        key = PickleSerializer.loads(obj)[0]
+    while True:
+        try:
+            pickled = loads(sys.stdin)
+        except EOFError:
+            return
+        key = cPickle.loads(pickled)[0]
         output(str(hashFunc(key)))
-        output(obj)
-
-
-def do_reduce():
-    f = load_function()
-    acc = None
-    for obj in read_input():
-        acc = f(PickleSerializer.loads(obj), acc)
-    if acc is not None:
-        output(PickleSerializer.dumps(acc))
-
-
-def do_echo():
-    old_stdout.writelines(sys.stdin.readlines())
+        output(pickled)
 
 
 def main():
     command = sys.stdin.readline().strip()
-    if command == "map":
-        do_map(flat=False)
-    elif command == "flatmap":
-        do_map(flat=True)
+    if command == "pipeline":
+        do_pipeline()
     elif command == "combine_by_key":
         do_combine_by_key()
-    elif command == "reduce":
-        do_reduce()
     elif command == "shuffle_map_step":
         do_shuffle_map_step()
-    elif command == "echo":
-        do_echo()
     else:
         raise Exception("Unsupported command %s" % command)
--
cgit v1.2.3
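
Editor's note (not part of the patch): the core idea of this refactor is that each transformation becomes a function from an iterator to an iterator, and chaining transformations simply composes those functions, so the worker evaluates an entire map/filter/flatMap chain in one lazy pass over its input. Below is a minimal standalone sketch of that composition, written in the same Python 2 style as the patch (itertools.imap/ifilter); the helper names map_stage, filter_stage, flatmap_stage, and compose are illustrative only and do not appear in the patched code.

from itertools import chain, ifilter, imap

def map_stage(f):
    # Same shape as RDD.map: wrap an element-level function
    # as an iterator-to-iterator function.
    def func(iterator):
        return imap(f, iterator)
    return func

def filter_stage(f):
    # Same shape as RDD.filter.
    def func(iterator):
        return ifilter(f, iterator)
    return func

def flatmap_stage(f):
    # Same shape as RDD.flatMap.
    def func(iterator):
        return chain.from_iterable(imap(f, iterator))
    return func

def compose(prev_func, func):
    # Mirrors PipelinedRDD's pipeline_func: feed the upstream stage's output
    # iterator straight into the downstream stage, with no intermediate lists.
    def pipeline_func(iterator):
        return func(prev_func(iterator))
    return pipeline_func

if __name__ == '__main__':
    pipeline = compose(flatmap_stage(lambda x: [x, x]),
                       compose(map_stage(lambda x: x * 10),
                               filter_stage(lambda x: x > 10)))
    # One pass over the input evaluates flatMap -> map -> filter lazily,
    # the same way the worker's do_pipeline() drives f(read_input()).
    print list(pipeline(iter([1, 2, 3])))  # [20, 20, 30, 30]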