Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/rdd.py    | 12
-rw-r--r--  python/pyspark/sql.py    |  2
-rw-r--r--  python/pyspark/tests.py  |  8
3 files changed, 16 insertions, 6 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 8ed89e2f97..dc6497772e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2073,6 +2073,12 @@ class PipelinedRDD(RDD):
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
         self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
+        self._broadcast = None
+
+    def __del__(self):
+        if self._broadcast:
+            self._broadcast.unpersist()
+            self._broadcast = None
 
     @property
     def _jrdd(self):
@@ -2087,9 +2093,9 @@ class PipelinedRDD(RDD):
         # the serialized command will be compressed by broadcast
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
-        if pickled_command > (1 << 20):  # 1M
-            broadcast = self.ctx.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
+        if len(pickled_command) > (1 << 20):  # 1M
+            self._broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(self._broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
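
Note on the fix above: under CPython 2, `pickled_command > (1 << 20)` compares a str against an int, which does not raise but orders arbitrarily by type (any string sorts above any number), so the intended size gate fired for every command. The corrected check compares len(pickled_command) against 1 << 20 (1,048,576 bytes, i.e. 1 MiB), and the Broadcast handle is now kept on the RDD so __del__ can unpersist it. A minimal sketch of the same size-gated pattern, assuming only a live SparkContext `sc` (the `ship_command` helper name is invented for illustration):

    import cPickle as pickle

    ONE_MB = 1 << 20  # 1,048,576 bytes

    def ship_command(sc, command):
        # Serialize the closure; broadcast it only when it is genuinely large.
        payload = pickle.dumps(command, 2)
        broadcast = None
        if len(payload) > ONE_MB:              # length check, not object comparison
            broadcast = sc.broadcast(payload)  # ship the heavy bytes once
            payload = pickle.dumps(broadcast, 2)  # workers get just the handle
        return payload, broadcast  # caller keeps the handle to unpersist() later
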
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index d8bdf22355..974b5e287b 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -965,7 +965,7 @@ class SQLContext(object):
                    BatchedSerializer(PickleSerializer(), 1024))
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
-        if pickled_command > (1 << 20):  # 1M
+        if len(pickled_command) > (1 << 20):  # 1M
             broadcast = self._sc.broadcast(pickled_command)
             pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 7e2bbc9cb6..6fb6bc998c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -467,8 +467,12 @@ class TestRDDFunctions(PySparkTestCase):
     def test_large_closure(self):
         N = 1000000
         data = [float(i) for i in xrange(N)]
-        m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum()
-        self.assertEquals(N, m)
+        rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
+        self.assertEquals(N, rdd.first())
+        self.assertTrue(rdd._broadcast is not None)
+        rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1)
+        self.assertEqual(1, rdd.first())
+        self.assertTrue(rdd._broadcast is None)
 
     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))
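
The updated test exercises both paths of the new behavior: a closure over a million floats exceeds the 1 MiB gate and leaves a Broadcast handle in rdd._broadcast, while a trivial closure leaves it None. The point of storing the handle is cleanup when the RDD is garbage-collected; roughly, assuming an interactive Python 2 session with a live SparkContext `sc`:

    data = [float(i) for i in xrange(1000000)]  # pickles to well over 1 << 20 bytes
    rdd = sc.parallelize(range(1), 1).map(lambda x: len(data))
    rdd.first()                        # building _jrdd broadcasts the big closure
    assert rdd._broadcast is not None  # the handle was stored on the RDD
    del rdd                            # refcount hits zero; __del__ fires and calls
                                       # Broadcast.unpersist(), releasing the cached
                                       # blocks instead of leaking them
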