author    Davies Liu <davies@databricks.com>    2016-04-06 10:46:34 -0700
committer Yin Huai <yhuai@databricks.com>       2016-04-06 10:46:34 -0700
commit    90ca1844865baf96656a9e5efdf56f415f2646be (patch)
tree      6b15a9383fccdb3119e9e4179ff00ba0f44e72c8 /python
parent    59236e5c5b9d24f90fcf8d09b23ae8b06355657e (diff)
[SPARK-14418][PYSPARK] fix unpersist of Broadcast in Python
## What changes were proposed in this pull request?

Currently, Broadcast.unpersist() will remove the file of the broadcast, which should be the behavior of destroy(). This PR adds destroy() for Broadcast in Python, to match the semantics in Scala.

## How was this patch tested?

Added regression tests.

Author: Davies Liu <davies@databricks.com>

Closes #12189 from davies/py_unpersist.
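For context, a minimal sketch of the intended driver-side semantics after this patch (not part of the commit; it assumes a running SparkContext named `sc`):

    # Sketch of intended semantics (hypothetical driver script, not in this patch).
    from pyspark import SparkContext

    sc = SparkContext(appName="broadcast-demo")
    b = sc.broadcast([1, 2, 3])

    # unpersist() only drops cached copies on the executors; the broadcast
    # stays usable and will be re-sent to executors on next use.
    b.unpersist()
    print(sc.parallelize(range(1), 1).map(lambda x: len(b.value)).sum())  # 3

    # destroy() removes all data and metadata, including the driver-local
    # file, so any later use of the broadcast is expected to fail.
    b.destroy()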
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/broadcast.py  17
-rw-r--r--  python/pyspark/tests.py      15
2 files changed, 31 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 663c9abe08..a0b819220e 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -99,11 +99,26 @@ class Broadcast(object):
 
     def unpersist(self, blocking=False):
         """
-        Delete cached copies of this broadcast on the executors.
+        Delete cached copies of this broadcast on the executors. If the
+        broadcast is used after this is called, it will need to be
+        re-sent to each executor.
+
+        :param blocking: Whether to block until unpersisting has completed
         """
         if self._jbroadcast is None:
             raise Exception("Broadcast can only be unpersisted in driver")
         self._jbroadcast.unpersist(blocking)
+
+    def destroy(self):
+        """
+        Destroy all data and metadata related to this broadcast variable.
+        Use this with caution; once a broadcast variable has been destroyed,
+        it cannot be used again. This method blocks until destroy has
+        completed.
+        """
+        if self._jbroadcast is None:
+            raise Exception("Broadcast can only be destroyed in driver")
+        self._jbroadcast.destroy()
         os.unlink(self._path)
 
     def __reduce__(self):
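Worth noting from the hunk above: both methods raise when self._jbroadcast is None, which is the executor-side (deserialized) form of a Broadcast. A hypothetical sketch of that guard, assuming the path-based constructor that pyspark.broadcast.Broadcast uses for executor-side instances:

    # Hypothetical sketch: executor-side Broadcast objects carry only a path,
    # so _jbroadcast is None and unpersist()/destroy() must raise.
    from pyspark.broadcast import Broadcast

    b = Broadcast(path="/tmp/some_broadcast_file")  # assumed executor-side form
    try:
        b.destroy()
    except Exception as e:
        print(e)  # expected: "Broadcast can only be destroyed in driver"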
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 40fccb8c00..15c87e22f9 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -694,6 +694,21 @@ class RDDTests(ReusedPySparkTestCase):
         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
         self.assertEqual(N, m)
 
+    def test_unpersist(self):
+        N = 1000
+        data = [[float(i) for i in range(300)] for i in range(N)]
+        bdata = self.sc.broadcast(data)  # 3MB
+        bdata.unpersist()
+        m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
+        self.assertEqual(N, m)
+        bdata.destroy()
+        try:
+            self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
+        except Exception as e:
+            pass
+        else:
+            raise Exception("job should fail after destroy the broadcast")
+
     def test_multiple_broadcasts(self):
         N = 1 << 21
         b1 = self.sc.broadcast(set(range(N)))  # multiple blocks in JVM
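The regression test asserts failure with a try/except/else; an equivalent and slightly more idiomatic phrasing would use unittest's assertRaises context manager (a sketch only, assuming the same ReusedPySparkTestCase fixture, not part of this commit):

    def test_destroyed_broadcast_fails(self):
        # Sketch: same assertion as the tail of test_unpersist above.
        bdata = self.sc.broadcast([1.0] * 10)
        bdata.destroy()
        with self.assertRaises(Exception):
            self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()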