author     Josh Rosen <joshrosen@eecs.berkeley.edu>    2012-08-24 19:38:50 -0700
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>    2012-08-24 19:44:14 -0700
commit     f3b852ce66d193e3421eeecef71ea27bff73a94b (patch)
tree       da22dadd6e8a6aa7c417d86e4d066245f406ac04 /pyspark
parent     4b523004877cf94152225484de7683e9d17cdb56 (diff)
Refactor Python MappedRDD to use iterator pipelines.
Diffstat (limited to 'pyspark')
-rw-r--r--    pyspark/pyspark/rdd.py       83
-rw-r--r--    pyspark/pyspark/worker.py    55
2 files changed, 41 insertions, 97 deletions
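
The refactoring replaces the per-element functions that MappedRDD shipped to the Python worker (together with a per-operation command such as 'map', 'flatmap', or 'reduce') with functions that transform an entire iterator over a partition's elements. Consecutive transformations can then be collapsed by ordinary function composition and run by the worker under a single "pipeline" command. A minimal standalone sketch of that idea (the names here are illustrative, not part of the patch):

    def map_stage(f):
        # Each stage maps an iterator of inputs to an iterator of outputs.
        def func(iterator):
            return (f(x) for x in iterator)
        return func

    def filter_stage(p):
        def func(iterator):
            return (x for x in iterator if p(x))
        return func

    def compose(inner, outer):
        # Chaining two stages yields another iterator-to-iterator function,
        # mirroring pipeline_func in PipelinedRDD.__init__ below.
        def pipeline(iterator):
            return outer(inner(iterator))
        return pipeline

    stage = compose(map_stage(lambda x: x * 2), filter_stage(lambda x: x > 4))
    print(list(stage(iter([1, 2, 3, 4]))))  # [6, 8]
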
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index ff9c483032..7d280d8844 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -1,4 +1,5 @@
 from base64 import standard_b64encode as b64enc
+from itertools import chain, ifilter, imap
 
 from pyspark import cloudpickle
 from pyspark.serializers import PickleSerializer
@@ -15,8 +16,6 @@ class RDD(object):
 
     @classmethod
     def _get_pipe_command(cls, command, functions):
-        if functions and not isinstance(functions, (list, tuple)):
-            functions = [functions]
         worker_args = [command]
         for f in functions:
             worker_args.append(b64enc(cloudpickle.dumps(f)))
@@ -28,7 +27,8 @@ class RDD(object):
         return self
 
     def map(self, f, preservesPartitioning=False):
-        return MappedRDD(self, f, preservesPartitioning)
+        def func(iterator): return imap(f, iterator)
+        return PipelinedRDD(self, func, preservesPartitioning)
 
     def flatMap(self, f):
         """
@@ -38,7 +38,8 @@ class RDD(object):
         >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
         [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
         """
-        return MappedRDD(self, f, preservesPartitioning=False, command='flatmap')
+        def func(iterator): return chain.from_iterable(imap(f, iterator))
+        return PipelinedRDD(self, func)
 
     def filter(self, f):
         """
@@ -46,10 +47,10 @@ class RDD(object):
         >>> rdd.filter(lambda x: x % 2 == 0).collect()
         [2, 4]
         """
-        def filter_func(x): return x if f(x) else None
-        return RDD(self._pipe(filter_func), self.ctx)
+        def func(iterator): return ifilter(f, iterator)
+        return PipelinedRDD(self, func)
 
-    def _pipe(self, functions, command="map"):
+    def _pipe(self, functions, command):
         class_manifest = self._jrdd.classManifest()
         pipe_command = RDD._get_pipe_command(command, functions)
         python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
@@ -128,7 +129,16 @@ class RDD(object):
         >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
         10
         """
-        vals = MappedRDD(self, f, command="reduce", preservesPartitioning=False).collect()
+        def func(iterator):
+            acc = None
+            for obj in iterator:
+                if acc is None:
+                    acc = obj
+                else:
+                    acc = f(obj, acc)
+            if acc is not None:
+                yield acc
+        vals = PipelinedRDD(self, func).collect()
         return reduce(f, vals)
 
     # TODO: fold
@@ -230,8 +240,6 @@ class RDD(object):
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
         return RDD(jrdd, self.ctx)
 
-
-
     def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
                      numSplits=None):
         """
@@ -297,7 +305,7 @@ class RDD(object):
 
     # TODO: file saving
 
-class MappedRDD(RDD):
+class PipelinedRDD(RDD):
     """
     Pipelined maps:
     >>> rdd = sc.parallelize([1, 2, 3, 4])
@@ -313,68 +321,29 @@ class MappedRDD(RDD):
     >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
     20
     """
-    def __init__(self, prev, func, preservesPartitioning=False, command='map'):
-        if isinstance(prev, MappedRDD) and not prev.is_cached:
+    def __init__(self, prev, func, preservesPartitioning=False):
+        if isinstance(prev, PipelinedRDD) and not prev.is_cached:
             prev_func = prev.func
-            if command == 'reduce':
-                if prev.command == 'flatmap':
-                    def flatmap_reduce_func(x, acc):
-                        values = prev_func(x)
-                        if values is None:
-                            return acc
-                        if not acc:
-                            if len(values) == 1:
-                                return values[0]
-                            else:
-                                return reduce(func, values[1:], values[0])
-                        else:
-                            return reduce(func, values, acc)
-                    self.func = flatmap_reduce_func
-                else:
-                    def reduce_func(x, acc):
-                        val = prev_func(x)
-                        if not val:
-                            return acc
-                        if acc is None:
-                            return val
-                        else:
-                            return func(val, acc)
-                    self.func = reduce_func
-            else:
-                if prev.command == 'flatmap':
-                    command = 'flatmap'
-                    self.func = lambda x: (func(y) for y in prev_func(x))
-                else:
-                    self.func = lambda x: func(prev_func(x))
-
+            def pipeline_func(iterator):
+                return func(prev_func(iterator))
+            self.func = pipeline_func
             self.preservesPartitioning = \
                 prev.preservesPartitioning and preservesPartitioning
             self._prev_jrdd = prev._prev_jrdd
-            self.is_pipelined = True
         else:
-            if command == 'reduce':
-                def reduce_func(val, acc):
-                    if acc is None:
-                        return val
-                    else:
-                        return func(val, acc)
-                self.func = reduce_func
-            else:
-                self.func = func
+            self.func = func
             self.preservesPartitioning = preservesPartitioning
             self._prev_jrdd = prev._jrdd
-            self.is_pipelined = False
         self.is_cached = False
         self.ctx = prev.ctx
         self.prev = prev
         self._jrdd_val = None
-        self.command = command
 
     @property
     def _jrdd(self):
         if not self._jrdd_val:
             funcs = [self.func]
-            pipe_command = RDD._get_pipe_command(self.command, funcs)
+            pipe_command = RDD._get_pipe_command("pipeline", funcs)
             class_manifest = self._prev_jrdd.classManifest()
             python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
                 pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
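
Two consequences of the rdd.py changes above are worth noting. First, when prev is itself an uncached PipelinedRDD, the constructor composes the new func with prev.func, so a chain like rdd.map(f).filter(g) ships a single combined function to the worker. Second, reduce() becomes a pipeline stage that folds each partition down to at most one value; the driver then collects those partial results and reduces them once more. A rough standalone sketch of that per-partition reduction under the same iterator protocol (partial_reduce, partitions, and operator.add are made-up example names and data, not part of the patch):

    from functools import reduce  # the builtin in Python 2
    from operator import add

    def partial_reduce(f):
        def func(iterator):
            acc = None
            for obj in iterator:
                acc = obj if acc is None else f(obj, acc)
            if acc is not None:
                yield acc          # at most one value per partition
        return func

    partitions = [[1, 2, 3], [], [4, 5]]
    vals = [v for part in partitions for v in partial_reduce(add)(iter(part))]
    print(reduce(add, vals))       # 15; empty partitions contribute nothing
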
diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py
index b13ed5699a..76b09918e7 100644
--- a/pyspark/pyspark/worker.py
+++ b/pyspark/pyspark/worker.py
@@ -25,17 +25,17 @@ def output(x):
 def read_input():
     try:
         while True:
-            yield loads(sys.stdin)
+            yield cPickle.loads(loads(sys.stdin))
     except EOFError:
         return
 
+
 def do_combine_by_key():
     create_combiner = load_function()
     merge_value = load_function()
     merge_combiners = load_function() # TODO: not used.
     combiners = {}
-    for obj in read_input():
-        (key, value) = PickleSerializer.loads(obj)
+    for (key, value) in read_input():
         if key not in combiners:
             combiners[key] = create_combiner(value)
         else:
@@ -44,57 +44,32 @@ def do_combine_by_key():
         output(PickleSerializer.dumps((key, combiner)))
 
 
-def do_map(flat=False):
+def do_pipeline():
     f = load_function()
-    for obj in read_input():
-        try:
-            out = f(PickleSerializer.loads(obj))
-            if out is not None:
-                if flat:
-                    for x in out:
-                        output(PickleSerializer.dumps(x))
-                else:
-                    output(PickleSerializer.dumps(out))
-        except:
-            sys.stderr.write("Error processing obj %s\n" % repr(obj))
-            raise
+    for obj in f(read_input()):
+        output(PickleSerializer.dumps(obj))
 
 
 def do_shuffle_map_step():
     hashFunc = load_function()
-    for obj in read_input():
-        key = PickleSerializer.loads(obj)[0]
+    while True:
+        try:
+            pickled = loads(sys.stdin)
+        except EOFError:
+            return
+        key = cPickle.loads(pickled)[0]
         output(str(hashFunc(key)))
-        output(obj)
-
-
-def do_reduce():
-    f = load_function()
-    acc = None
-    for obj in read_input():
-        acc = f(PickleSerializer.loads(obj), acc)
-    if acc is not None:
-        output(PickleSerializer.dumps(acc))
-
-
-def do_echo():
-    old_stdout.writelines(sys.stdin.readlines())
+        output(pickled)
 
 
 def main():
     command = sys.stdin.readline().strip()
-    if command == "map":
-        do_map(flat=False)
-    elif command == "flatmap":
-        do_map(flat=True)
+    if command == "pipeline":
+        do_pipeline()
     elif command == "combine_by_key":
         do_combine_by_key()
-    elif command == "reduce":
-        do_reduce()
     elif command == "shuffle_map_step":
         do_shuffle_map_step()
-    elif command == "echo":
-        do_echo()
     else:
         raise Exception("Unsupported command %s" % command)
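
On the worker side, main() now reads one command name from stdin and, for "pipeline", loads a single composed function, feeds it the stream of deserialized input records, and pickles each yielded result back to the JVM. A toy sketch of that dispatch with in-memory stand-ins for the real stdin framing and base64/pickle decoding (the helper names below are illustrative, not the worker's actual API):

    import pickle  # the worker uses cPickle under Python 2

    def do_pipeline(func, records):
        # func is the single composed iterator-to-iterator function shipped
        # by PipelinedRDD; every object it yields is pickled back out.
        for obj in func(iter(records)):
            yield pickle.dumps(obj)

    def keep_even_then_double(iterator):
        # Example composed pipeline: keep the even inputs, then double them.
        return (x * 2 for x in iterator if x % 2 == 0)

    out = list(do_pipeline(keep_even_then_double, [1, 2, 3, 4]))
    print([pickle.loads(p) for p in out])  # [4, 8]
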