aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/broadcast.py17
-rw-r--r--python/pyspark/tests.py15
2 files changed, 31 insertions, 1 deletions
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 663c9abe08..a0b819220e 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -99,11 +99,26 @@ class Broadcast(object):
def unpersist(self, blocking=False):
"""
- Delete cached copies of this broadcast on the executors.
+ Delete cached copies of this broadcast on the executors. If the
+ broadcast is used after this is called, it will need to be
+ re-sent to each executor.
+
+ :param blocking: Whether to block until unpersisting has completed
"""
if self._jbroadcast is None:
raise Exception("Broadcast can only be unpersisted in driver")
self._jbroadcast.unpersist(blocking)
+
+ def destroy(self):
+ """
+ Destroy all data and metadata related to this broadcast variable.
+ Use this with caution; once a broadcast variable has been destroyed,
+ it cannot be used again. This method blocks until destroy has
+ completed.
+ """
+ if self._jbroadcast is None:
+ raise Exception("Broadcast can only be destroyed in driver")
+ self._jbroadcast.destroy()
os.unlink(self._path)
def __reduce__(self):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 40fccb8c00..15c87e22f9 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -694,6 +694,21 @@ class RDDTests(ReusedPySparkTestCase):
m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
self.assertEqual(N, m)
+ def test_unpersist(self):
+ N = 1000
+ data = [[float(i) for i in range(300)] for i in range(N)]
+ bdata = self.sc.broadcast(data) # 3MB
+ bdata.unpersist()
+ m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
+ self.assertEqual(N, m)
+ bdata.destroy()
+ try:
+ self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
+ except Exception as e:
+ pass
+ else:
+ raise Exception("job should fail after destroy the broadcast")
+
def test_multiple_broadcasts(self):
N = 1 << 21
b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM