From e77fa81a61798c89d5a9b6c9dc067d11785254b7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Sep 2014 18:11:48 -0700 Subject: [SPARK-3554] [PySpark] use broadcast automatically for large closure Py4j can not handle large string efficiently, so we should use broadcast for large closure automatically. (Broadcast use local filesystem to pass through data). Author: Davies Liu Closes #2417 from davies/command and squashes the following commits: fbf4e97 [Davies Liu] bugfix aefd508 [Davies Liu] use broadcast automatically for large closure --- python/pyspark/rdd.py | 4 ++++ python/pyspark/sql.py | 8 ++++++-- python/pyspark/tests.py | 6 ++++++ python/pyspark/worker.py | 4 +++- 4 files changed, 19 insertions(+), 3 deletions(-) (limited to 'python') diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cb09c191be..b43606b730 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2061,8 +2061,12 @@ class PipelinedRDD(RDD): self._jrdd_deserializer = NoOpSerializer() command = (self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer) + # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() pickled_command = ser.dumps(command) + if pickled_command > (1 << 20): # 1M + broadcast = self.ctx.broadcast(pickled_command) + pickled_command = ser.dumps(broadcast) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 8f6dbab240..42a9920f10 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -27,7 +27,7 @@ import warnings from array import array from operator import itemgetter -from pyspark.rdd import RDD, PipelinedRDD +from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -975,7 +975,11 @@ class SQLContext(object): command = (func, BatchedSerializer(PickleSerializer(), 1024), BatchedSerializer(PickleSerializer(), 1024)) - pickled_command = CloudPickleSerializer().dumps(command) + ser = CloudPickleSerializer() + pickled_command = ser.dumps(command) + if pickled_command > (1 << 20): # 1M + broadcast = self._sc.broadcast(pickled_command) + pickled_command = ser.dumps(broadcast) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self._sc._pickled_broadcast_vars], self._sc._gateway._gateway_client) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 0b3854347a..7301966e48 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -434,6 +434,12 @@ class TestRDDFunctions(PySparkTestCase): m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEquals(N, m) + def test_large_closure(self): + N = 1000000 + data = [float(i) for i in xrange(N)] + m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum() + self.assertEquals(N, m) + def test_zip_with_different_serializers(self): a = self.sc.parallelize(range(5)) b = self.sc.parallelize(range(100, 105)) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 252176ac65..d6c06e2dbe 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -77,10 +77,12 @@ def main(infile, outfile): _broadcastRegistry[bid] = Broadcast(bid, value) else: bid = - bid - 1 - _broadcastRegistry.remove(bid) + _broadcastRegistry.pop(bid) _accumulatorRegistry.clear() command = pickleSer._read_with_length(infile) + if isinstance(command, Broadcast): + command = pickleSer.loads(command.value) (func, deserializer, serializer) = command init_time = time.time() iterator = deserializer.load_stream(infile) -- cgit v1.2.3