aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/broadcast.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/broadcast.py')
-rw-r--r--python/pyspark/broadcast.py95
1 files changed, 64 insertions, 31 deletions
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 01cac3c72c..6b8a8b256a 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -15,21 +15,10 @@
# limitations under the License.
#
-"""
->>> from pyspark.context import SparkContext
->>> sc = SparkContext('local', 'test')
->>> b = sc.broadcast([1, 2, 3, 4, 5])
->>> b.value
-[1, 2, 3, 4, 5]
->>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
-[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
->>> b.unpersist()
-
->>> large_broadcast = sc.broadcast(list(range(10000)))
-"""
import os
-
-from pyspark.serializers import LargeObjectSerializer
+import cPickle
+import gc
+from tempfile import NamedTemporaryFile
__all__ = ['Broadcast']
@@ -49,44 +38,88 @@ def _from_id(bid):
class Broadcast(object):
"""
- A broadcast variable created with
- L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}.
+ A broadcast variable created with L{SparkContext.broadcast()}.
Access its value through C{.value}.
+
+ Examples:
+
+ >>> from pyspark.context import SparkContext
+ >>> sc = SparkContext('local', 'test')
+ >>> b = sc.broadcast([1, 2, 3, 4, 5])
+ >>> b.value
+ [1, 2, 3, 4, 5]
+ >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+ [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+ >>> b.unpersist()
+
+ >>> large_broadcast = sc.broadcast(range(10000))
"""
- def __init__(self, bid, value, java_broadcast=None,
- pickle_registry=None, path=None):
+ def __init__(self, sc=None, value=None, pickle_registry=None, path=None):
"""
- Should not be called directly by users -- use
- L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}
+ Should not be called directly by users -- use L{SparkContext.broadcast()}
instead.
"""
- self.bid = bid
- if path is None:
- self._value = value
- self._jbroadcast = java_broadcast
- self._pickle_registry = pickle_registry
- self.path = path
+ if sc is not None:
+ f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
+ self._path = self.dump(value, f)
+ self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path)
+ self._pickle_registry = pickle_registry
+ else:
+ self._jbroadcast = None
+ self._path = path
+
+ def dump(self, value, f):
+ if isinstance(value, basestring):
+ if isinstance(value, unicode):
+ f.write('U')
+ value = value.encode('utf8')
+ else:
+ f.write('S')
+ f.write(value)
+ else:
+ f.write('P')
+ cPickle.dump(value, f, 2)
+ f.close()
+ return f.name
+
+ def load(self, path):
+ with open(path, 'rb', 1 << 20) as f:
+ flag = f.read(1)
+ data = f.read()
+ if flag == 'P':
+ # cPickle.loads() may create lots of objects, disable GC
+ # temporary for better performance
+ gc.disable()
+ try:
+ return cPickle.loads(data)
+ finally:
+ gc.enable()
+ else:
+ return data.decode('utf8') if flag == 'U' else data
@property
def value(self):
""" Return the broadcasted value
"""
- if not hasattr(self, "_value") and self.path is not None:
- ser = LargeObjectSerializer()
- self._value = ser.load_stream(open(self.path)).next()
+ if not hasattr(self, "_value") and self._path is not None:
+ self._value = self.load(self._path)
return self._value
def unpersist(self, blocking=False):
"""
Delete cached copies of this broadcast on the executors.
"""
+ if self._jbroadcast is None:
+ raise Exception("Broadcast can only be unpersisted in driver")
self._jbroadcast.unpersist(blocking)
- os.unlink(self.path)
+ os.unlink(self._path)
def __reduce__(self):
+ if self._jbroadcast is None:
+ raise Exception("Broadcast can only be serialized in driver")
self._pickle_registry.add(self)
- return (_from_id, (self.bid, ))
+ return _from_id, (self._jbroadcast.id(),)
if __name__ == "__main__":