author    Josh Rosen <joshrosen@eecs.berkeley.edu>    2012-12-24 15:01:13 -0800
committer Josh Rosen <joshrosen@eecs.berkeley.edu>    2012-12-24 15:01:13 -0800
commit    ccd075cf960df6c6c449b709515cdd81499a52be (patch)
tree      ecf2d94b4949625d3db03ebe3726a53aaaef5fba /pyspark
parent    2ccf3b665280bf5b0919e3801d028126cb070dbd (diff)
Reduce object overhead in Pyspark shuffle and collect
Diffstat (limited to 'pyspark')
-rw-r--r--  pyspark/pyspark/rdd.py  19
1 file changed, 14 insertions, 5 deletions
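
To make the collect() half of the change concrete before reading the diff, here is a minimal standalone sketch of the batching idea. It is an illustration only, not PySpark code: pickle.dumps stands in for PySpark's dump_pickle helper, and the sample data is invented.

import pickle

# One partition's worth of (key, value) records (invented data).
partition = [(i, i * 2) for i in range(1000)]

# Roughly the old shape: one serialized object per record has to cross the
# Python/Java boundary.
per_record = [pickle.dumps(kv, 2) for kv in partition]

# Roughly the new shape: the whole partition is pickled as a single list, so
# only one serialized object per partition crosses the boundary.
per_partition = pickle.dumps(partition, 2)

assert len(per_record) == len(partition)   # 1000 serialized objects
assert isinstance(per_partition, bytes)    # a single serialized object
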
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 85a24c6854..708ea6eb55 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -145,8 +145,10 @@ class RDD(object):
         self.map(f).collect()  # Force evaluation

     def collect(self):
-        pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect())
-        return load_pickle(bytes(pickle))
+        def asList(iterator):
+            yield list(iterator)
+        pickles = self.mapPartitions(asList)._jrdd.rdd().collect()
+        return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles))

     def reduce(self, f):
         """
@@ -319,16 +321,23 @@ class RDD(object):
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
         def add_shuffle_key(iterator):
+            buckets = defaultdict(list)
             for (k, v) in iterator:
-                yield str(hashFunc(k))
-                yield dump_pickle((k, v))
+                buckets[hashFunc(k) % numSplits].append((k, v))
+            for (split, items) in buckets.iteritems():
+                yield str(split)
+                yield dump_pickle(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
         partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
+        # Transferring O(n) objects to Java is too expensive. Instead, we'll
+        # form the hash buckets in Python, transferring O(numSplits) objects
+        # to Java. Each object is a (splitNumber, [objects]) pair.
         jrdd = pairRDD.partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
-        return RDD(jrdd, self.ctx)
+        # Flatten the resulting RDD:
+        return RDD(jrdd, self.ctx).flatMap(lambda items: items)

     def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
                      numSplits=None):
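
The shuffle half follows the same pattern. Below is a minimal standalone sketch of the bucketing idea, assuming pickle.dumps in place of dump_pickle, the built-in hash as hashFunc, and invented sample data; it illustrates the technique rather than reproducing the PySpark implementation.

import pickle
from collections import defaultdict

def add_shuffle_key(iterator, numSplits, hashFunc=hash):
    # Group records into numSplits buckets locally, then emit one
    # (split id, pickled bucket) pair per non-empty bucket, so roughly
    # O(numSplits) objects cross the Python/Java boundary instead of O(n).
    buckets = defaultdict(list)
    for (k, v) in iterator:
        buckets[hashFunc(k) % numSplits].append((k, v))
    for split, items in buckets.items():
        yield str(split)
        yield pickle.dumps(items, 2)

pairs = [("a", 1), ("b", 2), ("c", 3), ("a", 4)]
stream = list(add_shuffle_key(iter(pairs), numSplits=2))
# stream alternates split ids and pickled buckets, e.g. ['0', <bytes>, '1', <bytes>].

Because each partitioned element is now a whole bucket (a list of pairs), the trailing flatMap(lambda items: items) in the diff is what restores the one-record-per-element view of the resulting RDD.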