diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/broadcast.py | 17 | ||||
-rw-r--r-- | python/pyspark/tests.py | 15 |
2 files changed, 31 insertions, 1 deletions
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 663c9abe08..a0b819220e 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -99,11 +99,26 @@ class Broadcast(object): def unpersist(self, blocking=False): """ - Delete cached copies of this broadcast on the executors. + Delete cached copies of this broadcast on the executors. If the + broadcast is used after this is called, it will need to be + re-sent to each executor. + + :param blocking: Whether to block until unpersisting has completed """ if self._jbroadcast is None: raise Exception("Broadcast can only be unpersisted in driver") self._jbroadcast.unpersist(blocking) + + def destroy(self): + """ + Destroy all data and metadata related to this broadcast variable. + Use this with caution; once a broadcast variable has been destroyed, + it cannot be used again. This method blocks until destroy has + completed. + """ + if self._jbroadcast is None: + raise Exception("Broadcast can only be destroyed in driver") + self._jbroadcast.destroy() os.unlink(self._path) def __reduce__(self): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 40fccb8c00..15c87e22f9 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -694,6 +694,21 @@ class RDDTests(ReusedPySparkTestCase): m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEqual(N, m) + def test_unpersist(self): + N = 1000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 3MB + bdata.unpersist() + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEqual(N, m) + bdata.destroy() + try: + self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + except Exception as e: + pass + else: + raise Exception("job should fail after destroy the broadcast") + def test_multiple_broadcasts(self): N = 1 << 21 b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM |