path: root/pyspark/pyspark/rdd.py
author     Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-08-25 14:19:07 -0700
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-08-27 00:19:26 -0700
commit     6904cb77d4306a14891cc71338c8f9f966d009f1 (patch)
tree       a9f192b74b3731898c02e115158a95587c77e69d /pyspark/pyspark/rdd.py
parent     8b64b7ecd80c52f2f09a517f1517c0ece7a3d57f (diff)
Use local combiners in Python API combineByKey().
Diffstat (limited to 'pyspark/pyspark/rdd.py')
-rw-r--r--  pyspark/pyspark/rdd.py  |  33 ++++++++++++++++++++++++---------
1 file changed, 24 insertions(+), 9 deletions(-)
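
Before this change, combineByKey() shuffled every record with partitionBy()
and then piped the three combiner functions to a JVM-side "combine_by_key"
stage; after it, each partition pre-aggregates in Python first, so at most
one combiner per key per partition crosses the shuffle. A minimal usage
sketch (hedged: `sc` is assumed to be a live SparkContext, and the sample
data and lambdas are illustrative, not part of this commit):

    # Word count via combineByKey(); with local combining, each input
    # partition emits at most one partial count per distinct word.
    pairs = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
    counts = pairs.combineByKey(
        lambda v: v,             # createCombiner: seed from the first value
        lambda c, v: c + v,      # mergeValue: fold in another value, pre-shuffle
        lambda c1, c2: c1 + c2)  # mergeCombiners: merge partials post-shuffle
    counts.collect()             # [('a', 2), ('b', 1)], order may vary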
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index fd41ea0b17..3528b8f308 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -46,7 +46,7 @@ class RDD(object):
         [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
         """
         def func(iterator): return chain.from_iterable(imap(f, iterator))
-        return PipelinedRDD(self, func)
+        return self.mapPartitions(func)
 
     def mapPartitions(self, f):
         """
@@ -64,7 +64,7 @@ class RDD(object):
         [2, 4]
         """
         def func(iterator): return ifilter(f, iterator)
-        return PipelinedRDD(self, func)
+        return self.mapPartitions(func)
 
     def _pipe(self, functions, command):
         class_manifest = self._jrdd.classManifest()
@@ -118,7 +118,7 @@ class RDD(object):
         [1, 2]
         """
         def func(iterator): yield list(iterator)
-        return PipelinedRDD(self, func)
+        return self.mapPartitions(func)
 
     def cartesian(self, other):
         """
@@ -167,7 +167,7 @@ class RDD(object):
                     acc = f(obj, acc)
             if acc is not None:
                 yield acc
-        vals = PipelinedRDD(self, func).collect()
+        vals = self.mapPartitions(func).collect()
         return reduce(f, vals)
 
     def fold(self, zeroValue, op):
@@ -187,7 +187,7 @@ class RDD(object):
             for obj in iterator:
                 acc = op(obj, acc)
             yield acc
-        vals = PipelinedRDD(self, func).collect()
+        vals = self.mapPartitions(func).collect()
         return reduce(op, vals, zeroValue)
 
     # TODO: aggregate
@@ -330,10 +330,25 @@ class RDD(object):
"""
if numSplits is None:
numSplits = self.ctx.defaultParallelism
- shuffled = self.partitionBy(numSplits)
- functions = [createCombiner, mergeValue, mergeCombiners]
- jpairs = shuffled._pipe(functions, "combine_by_key")
- return RDD(jpairs, self.ctx)
+ def combineLocally(iterator):
+ combiners = {}
+ for (k, v) in iterator:
+ if k not in combiners:
+ combiners[k] = createCombiner(v)
+ else:
+ combiners[k] = mergeValue(combiners[k], v)
+ return combiners.iteritems()
+ locally_combined = self.mapPartitions(combineLocally)
+ shuffled = locally_combined.partitionBy(numSplits)
+ def _mergeCombiners(iterator):
+ combiners = {}
+ for (k, v) in iterator:
+ if not k in combiners:
+ combiners[k] = v
+ else:
+ combiners[k] = mergeCombiners(combiners[k], v)
+ return combiners.iteritems()
+ return shuffled.mapPartitions(_mergeCombiners)
def groupByKey(self, numSplits=None):
"""