diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/rdd.py   | 12 +++++++++---
-rw-r--r--  python/pyspark/sql.py   |  2 +-
-rw-r--r--  python/pyspark/tests.py |  8 ++++++--
3 files changed, 16 insertions, 6 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 8ed89e2f97..dc6497772e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2073,6 +2073,12 @@ class PipelinedRDD(RDD):
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
         self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
+        self._broadcast = None
+
+    def __del__(self):
+        if self._broadcast:
+            self._broadcast.unpersist()
+            self._broadcast = None
 
     @property
     def _jrdd(self):
@@ -2087,9 +2093,9 @@ class PipelinedRDD(RDD):
             # the serialized command will be compressed by broadcast
             ser = CloudPickleSerializer()
             pickled_command = ser.dumps(command)
-            if pickled_command > (1 << 20):  # 1M
-                broadcast = self.ctx.broadcast(pickled_command)
-                pickled_command = ser.dumps(broadcast)
+            if len(pickled_command) > (1 << 20):  # 1M
+                self._broadcast = self.ctx.broadcast(pickled_command)
+                pickled_command = ser.dumps(self._broadcast)
             broadcast_vars = ListConverter().convert(
                 [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
                 self.ctx._gateway._gateway_client)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index d8bdf22355..974b5e287b 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -965,7 +965,7 @@ class SQLContext(object):
                                   BatchedSerializer(PickleSerializer(), 1024))
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
-        if pickled_command > (1 << 20):  # 1M
+        if len(pickled_command) > (1 << 20):  # 1M
             broadcast = self._sc.broadcast(pickled_command)
             pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 7e2bbc9cb6..6fb6bc998c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -467,8 +467,12 @@ class TestRDDFunctions(PySparkTestCase):
     def test_large_closure(self):
         N = 1000000
         data = [float(i) for i in xrange(N)]
-        m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum()
-        self.assertEquals(N, m)
+        rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
+        self.assertEquals(N, rdd.first())
+        self.assertTrue(rdd._broadcast is not None)
+        rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1)
+        self.assertEqual(1, rdd.first())
+        self.assertTrue(rdd._broadcast is None)
 
     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))