about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--  python/pyspark/rdd.py     4
-rw-r--r--  python/pyspark/sql.py     8
-rw-r--r--  python/pyspark/tests.py   6
-rw-r--r--  python/pyspark/worker.py  4
4 files changed, 19 insertions, 3 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index cb09c191be..b43606b730 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2061,8 +2061,12 @@ class PipelinedRDD(RDD):
self._jrdd_deserializer = NoOpSerializer()
command = (self.func, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
+ # the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
+ if len(pickled_command) > (1 << 20):  # 1M
+ broadcast = self.ctx.broadcast(pickled_command)
+ pickled_command = ser.dumps(broadcast)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 8f6dbab240..42a9920f10 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -27,7 +27,7 @@ import warnings
from array import array
from operator import itemgetter
-from pyspark.rdd import RDD, PipelinedRDD
+from pyspark.rdd import RDD
from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@@ -975,7 +975,11 @@ class SQLContext(object):
command = (func,
BatchedSerializer(PickleSerializer(), 1024),
BatchedSerializer(PickleSerializer(), 1024))
- pickled_command = CloudPickleSerializer().dumps(command)
+ ser = CloudPickleSerializer()
+ pickled_command = ser.dumps(command)
+ if len(pickled_command) > (1 << 20):  # 1M
+ broadcast = self._sc.broadcast(pickled_command)
+ pickled_command = ser.dumps(broadcast)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self._sc._pickled_broadcast_vars],
self._sc._gateway._gateway_client)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 0b3854347a..7301966e48 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -434,6 +434,12 @@ class TestRDDFunctions(PySparkTestCase):
m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
self.assertEquals(N, m)
+ def test_large_closure(self):
+ N = 1000000
+ data = [float(i) for i in xrange(N)]
+ m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum()
+ self.assertEquals(N, m)
+
def test_zip_with_different_serializers(self):
a = self.sc.parallelize(range(5))
b = self.sc.parallelize(range(100, 105))
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 252176ac65..d6c06e2dbe 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -77,10 +77,12 @@ def main(infile, outfile):
_broadcastRegistry[bid] = Broadcast(bid, value)
else:
bid = - bid - 1
- _broadcastRegistry.remove(bid)
+ _broadcastRegistry.pop(bid)
_accumulatorRegistry.clear()
command = pickleSer._read_with_length(infile)
+ if isinstance(command, Broadcast):
+ command = pickleSer.loads(command.value)
(func, deserializer, serializer) = command
init_time = time.time()
iterator = deserializer.load_stream(infile)