aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/broadcast.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/broadcast.py')
-rw-r--r--python/pyspark/broadcast.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
new file mode 100644
index 0000000000..93876fa738
--- /dev/null
+++ b/python/pyspark/broadcast.py
@@ -0,0 +1,48 @@
+"""
+>>> from pyspark.context import SparkContext
+>>> sc = SparkContext('local', 'test')
+>>> b = sc.broadcast([1, 2, 3, 4, 5])
+>>> b.value
+[1, 2, 3, 4, 5]
+
+>>> from pyspark.broadcast import _broadcastRegistry
+>>> _broadcastRegistry[b.bid] = b
+>>> from cPickle import dumps, loads
+>>> loads(dumps(b)).value
+[1, 2, 3, 4, 5]
+
+>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+
+>>> large_broadcast = sc.broadcast(list(range(10000)))
+"""
+# Holds broadcasted data received from Java, keyed by its id.
+_broadcastRegistry = {}
+
+
+def _from_id(bid):
+ from pyspark.broadcast import _broadcastRegistry
+ if bid not in _broadcastRegistry:
+ raise Exception("Broadcast variable '%s' not loaded!" % bid)
+ return _broadcastRegistry[bid]
+
+
+class Broadcast(object):
+ def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
+ self.value = value
+ self.bid = bid
+ self._jbroadcast = java_broadcast
+ self._pickle_registry = pickle_registry
+
+ def __reduce__(self):
+ self._pickle_registry.add(self)
+ return (_from_id, (self.bid, ))
+
+
+def _test():
+ import doctest
+ doctest.testmod()
+
+
+if __name__ == "__main__":
+ _test()