diff options
Diffstat (limited to 'python/pyspark/broadcast.py')
-rw-r--r-- | python/pyspark/broadcast.py | 95 |
1 files changed, 64 insertions, 31 deletions
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 01cac3c72c..6b8a8b256a 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -15,21 +15,10 @@ # limitations under the License. # -""" ->>> from pyspark.context import SparkContext ->>> sc = SparkContext('local', 'test') ->>> b = sc.broadcast([1, 2, 3, 4, 5]) ->>> b.value -[1, 2, 3, 4, 5] ->>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() -[1, 2, 3, 4, 5, 1, 2, 3, 4, 5] ->>> b.unpersist() - ->>> large_broadcast = sc.broadcast(list(range(10000))) -""" import os - -from pyspark.serializers import LargeObjectSerializer +import cPickle +import gc +from tempfile import NamedTemporaryFile __all__ = ['Broadcast'] @@ -49,44 +38,88 @@ def _from_id(bid): class Broadcast(object): """ - A broadcast variable created with - L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}. + A broadcast variable created with L{SparkContext.broadcast()}. Access its value through C{.value}. + + Examples: + + >>> from pyspark.context import SparkContext + >>> sc = SparkContext('local', 'test') + >>> b = sc.broadcast([1, 2, 3, 4, 5]) + >>> b.value + [1, 2, 3, 4, 5] + >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() + [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] + >>> b.unpersist() + + >>> large_broadcast = sc.broadcast(range(10000)) """ - def __init__(self, bid, value, java_broadcast=None, - pickle_registry=None, path=None): + def __init__(self, sc=None, value=None, pickle_registry=None, path=None): """ - Should not be called directly by users -- use - L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>} + Should not be called directly by users -- use L{SparkContext.broadcast()} instead. """ - self.bid = bid - if path is None: - self._value = value - self._jbroadcast = java_broadcast - self._pickle_registry = pickle_registry - self.path = path + if sc is not None: + f = NamedTemporaryFile(delete=False, dir=sc._temp_dir) + self._path = self.dump(value, f) + self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path) + self._pickle_registry = pickle_registry + else: + self._jbroadcast = None + self._path = path + + def dump(self, value, f): + if isinstance(value, basestring): + if isinstance(value, unicode): + f.write('U') + value = value.encode('utf8') + else: + f.write('S') + f.write(value) + else: + f.write('P') + cPickle.dump(value, f, 2) + f.close() + return f.name + + def load(self, path): + with open(path, 'rb', 1 << 20) as f: + flag = f.read(1) + data = f.read() + if flag == 'P': + # cPickle.loads() may create lots of objects, disable GC + # temporary for better performance + gc.disable() + try: + return cPickle.loads(data) + finally: + gc.enable() + else: + return data.decode('utf8') if flag == 'U' else data @property def value(self): """ Return the broadcasted value """ - if not hasattr(self, "_value") and self.path is not None: - ser = LargeObjectSerializer() - self._value = ser.load_stream(open(self.path)).next() + if not hasattr(self, "_value") and self._path is not None: + self._value = self.load(self._path) return self._value def unpersist(self, blocking=False): """ Delete cached copies of this broadcast on the executors. """ + if self._jbroadcast is None: + raise Exception("Broadcast can only be unpersisted in driver") self._jbroadcast.unpersist(blocking) - os.unlink(self.path) + os.unlink(self._path) def __reduce__(self): + if self._jbroadcast is None: + raise Exception("Broadcast can only be serialized in driver") self._pickle_registry.add(self) - return (_from_id, (self.bid, )) + return _from_id, (self._jbroadcast.id(),) if __name__ == "__main__": |