From ccd075cf960df6c6c449b709515cdd81499a52be Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Mon, 24 Dec 2012 15:01:13 -0800
Subject: Reduce object overhead in Pyspark shuffle and collect

---
 pyspark/pyspark/rdd.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 85a24c6854..708ea6eb55 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -145,8 +145,10 @@ class RDD(object):
         self.map(f).collect()  # Force evaluation
 
     def collect(self):
-        pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect())
-        return load_pickle(bytes(pickle))
+        def asList(iterator):
+            yield list(iterator)
+        pickles = self.mapPartitions(asList)._jrdd.rdd().collect()
+        return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles))
 
     def reduce(self, f):
         """
@@ -319,16 +321,23 @@ class RDD(object):
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
         def add_shuffle_key(iterator):
+            buckets = defaultdict(list)
             for (k, v) in iterator:
-                yield str(hashFunc(k))
-                yield dump_pickle((k, v))
+                buckets[hashFunc(k) % numSplits].append((k, v))
+            for (split, items) in buckets.iteritems():
+                yield str(split)
+                yield dump_pickle(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
         partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
+        # Transferring O(n) objects to Java is too expensive.  Instead, we'll
+        # form the hash buckets in Python, transferring O(numSplits) objects
+        # to Java.  Each object is a (splitNumber, [objects]) pair.
         jrdd = pairRDD.partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
-        return RDD(jrdd, self.ctx)
+        # Flatten the resulting RDD:
+        return RDD(jrdd, self.ctx).flatMap(lambda items: items)
 
     def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
                      numSplits=None):
-- 
cgit v1.2.3
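
The collect() change replaces one monolithic arrayAsPickle call with one pickled list per partition, flattened on the driver with chain.from_iterable. Below is a minimal standalone sketch of that flattening step, with pickle.loads standing in for PySpark's load_pickle and per_partition_pickles standing in for what _jrdd.rdd().collect() returns; both names are illustrative, not from the patch.

from itertools import chain
import pickle

# Pretend the JVM handed back one pickled list per partition, as the
# asList/mapPartitions step in the patch arranges.
per_partition_pickles = [pickle.dumps([1, 2, 3]), pickle.dumps([4, 5])]

# Unpickle each partition's list and flatten into a single result,
# mirroring collect()'s chain.from_iterable expression.
result = list(chain.from_iterable(pickle.loads(p) for p in per_partition_pickles))
print(result)  # [1, 2, 3, 4, 5]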
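
The partitionBy() change applies the same batching idea to the shuffle: rather than yielding one (hash, pickled pair) per record, add_shuffle_key builds numSplits hash buckets in Python and yields one pickled list per non-empty bucket, so only O(numSplits) objects cross the Python-to-Java boundary. A standalone sketch of the bucketing step, written for Python 3 (the patch itself is Python 2, hence iteritems); bucket_pairs, num_splits, and hash_func are illustrative names, not part of the patch.

from collections import defaultdict
import pickle

def bucket_pairs(iterator, num_splits, hash_func=hash):
    # Group (k, v) pairs by hash bucket, as add_shuffle_key does.
    buckets = defaultdict(list)
    for k, v in iterator:
        buckets[hash_func(k) % num_splits].append((k, v))
    # Emit one (split, pickled list) pair per non-empty bucket:
    # O(num_splits) objects instead of O(n).
    for split, items in buckets.items():
        yield split, pickle.dumps(items)

pairs = [("a", 1), ("b", 2), ("c", 3), ("a", 4)]
for split, blob in bucket_pairs(iter(pairs), num_splits=2):
    print(split, pickle.loads(blob))

The flatMap(lambda items: items) at the end of the patch is the receiving side of the same trick: after the Java partitioner routes each pickled bucket, every Python partition holds lists of pairs, and flatMap unrolls them back into individual records.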