diff options
author | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2012-12-24 15:01:13 -0800 |
---|---|---|
committer | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2012-12-24 15:01:13 -0800 |
commit | ccd075cf960df6c6c449b709515cdd81499a52be (patch) | |
tree | ecf2d94b4949625d3db03ebe3726a53aaaef5fba /pyspark | |
parent | 2ccf3b665280bf5b0919e3801d028126cb070dbd (diff) | |
download | spark-ccd075cf960df6c6c449b709515cdd81499a52be.tar.gz spark-ccd075cf960df6c6c449b709515cdd81499a52be.tar.bz2 spark-ccd075cf960df6c6c449b709515cdd81499a52be.zip |
Reduce object overhead in Pyspark shuffle and collect
Diffstat (limited to 'pyspark')
-rw-r--r-- | pyspark/pyspark/rdd.py | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 85a24c6854..708ea6eb55 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -145,8 +145,10 @@ class RDD(object): self.map(f).collect() # Force evaluation def collect(self): - pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect()) - return load_pickle(bytes(pickle)) + def asList(iterator): + yield list(iterator) + pickles = self.mapPartitions(asList)._jrdd.rdd().collect() + return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles)) def reduce(self, f): """ @@ -319,16 +321,23 @@ class RDD(object): if numSplits is None: numSplits = self.ctx.defaultParallelism def add_shuffle_key(iterator): + buckets = defaultdict(list) for (k, v) in iterator: - yield str(hashFunc(k)) - yield dump_pickle((k, v)) + buckets[hashFunc(k) % numSplits].append((k, v)) + for (split, items) in buckets.iteritems(): + yield str(split) + yield dump_pickle(items) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) + # Transferring O(n) objects to Java is too expensive. Instead, we'll + # form the hash buckets in Python, transferring O(numSplits) objects + # to Java. Each object is a (splitNumber, [objects]) pair. jrdd = pairRDD.partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - return RDD(jrdd, self.ctx) + # Flatten the resulting RDD: + return RDD(jrdd, self.ctx).flatMap(lambda items: items) def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numSplits=None): |