 core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala |  8
 python/pyspark/broadcast.py                                     | 37
 python/pyspark/context.py                                       | 20
 python/pyspark/rdd.py                                           |  5
 python/pyspark/serializers.py                                   | 17
 python/pyspark/tests.py                                         |  7
 python/pyspark/worker.py                                        |  8
 7 files changed, 81 insertions(+), 21 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 9f5c5bd30f..10210a2927 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -315,6 +315,14 @@ private[spark] object PythonRDD extends Logging {
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
+ def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
+ val file = new DataInputStream(new FileInputStream(filename))
+ val length = file.readInt()
+ val obj = new Array[Byte](length)
+ file.readFully(obj)
+ sc.broadcast(obj)
+ }
+
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
// The right way to implement this would be to use TypeTags to get the full
// type of T. Since I don't want to introduce breaking changes throughout the
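
The new readBroadcastFromFile expects the file to hold a 4-byte big-endian length followed by exactly that many payload bytes, which is the framing FramedSerializer.dump_stream emits on the Python side. A minimal sketch of a writer for that format (write_framed is a hypothetical helper, not part of the patch):

    import struct
    import zlib
    import cPickle

    def write_framed(path, value):
        # zlib level 1 matches the CompressedSerializer added below
        payload = zlib.compress(cPickle.dumps(value, 2), 1)
        with open(path, 'wb') as f:
            f.write(struct.pack('!i', len(payload)))  # consumed by file.readInt()
            f.write(payload)                          # consumed by file.readFully(obj)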
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index f3e64989ed..675a2fcd2f 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -21,18 +21,16 @@
>>> b = sc.broadcast([1, 2, 3, 4, 5])
>>> b.value
[1, 2, 3, 4, 5]
-
->>> from pyspark.broadcast import _broadcastRegistry
->>> _broadcastRegistry[b.bid] = b
->>> from cPickle import dumps, loads
->>> loads(dumps(b)).value
-[1, 2, 3, 4, 5]
-
>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+>>> b.unpersist()
>>> large_broadcast = sc.broadcast(list(range(10000)))
"""
+import os
+
+from pyspark.serializers import CompressedSerializer, PickleSerializer
+
# Holds broadcasted data received from Java, keyed by its id.
_broadcastRegistry = {}
@@ -52,17 +50,38 @@ class Broadcast(object):
Access its value through C{.value}.
"""
- def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
+ def __init__(self, bid, value, java_broadcast=None,
+ pickle_registry=None, path=None):
"""
Should not be called directly by users -- use
L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}
instead.
"""
- self.value = value
self.bid = bid
+ if path is None:
+ self.value = value
self._jbroadcast = java_broadcast
self._pickle_registry = pickle_registry
+ self.path = path
+
+ def unpersist(self, blocking=False):
+ self._jbroadcast.unpersist(blocking)
+ os.unlink(self.path)
def __reduce__(self):
self._pickle_registry.add(self)
return (_from_id, (self.bid, ))
+
+ def __getattr__(self, item):
+ if item == 'value' and self.path is not None:
+ ser = CompressedSerializer(PickleSerializer())
+ value = ser.load_stream(open(self.path, 'rb')).next()
+ self.value = value
+ return value
+
+ raise AttributeError(item)
+
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
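
With the constructor change above, the driver no longer holds the deserialized value; the first attribute access materializes it. A hypothetical driver-side session, assuming an active SparkContext sc as in the doctest:

    b = sc.broadcast(range(10000))  # value is written to a temp file, not kept in memory
    first = b.value                 # __getattr__ fires: open, decompress, unpickle, cache
    second = b.value                # plain instance attribute now; no file I/O
    b.unpersist()                   # remove executor copies and delete the temp file

Because __getattr__ is only invoked for attributes that are missing, the `self.value = value` caching line makes every access after the first an ordinary lookup.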
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6c04923881..a90870ed3a 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -29,7 +29,7 @@ from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
- PairDeserializer
+ PairDeserializer, CompressedSerializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
@@ -566,13 +566,17 @@ class SparkContext(object):
"""
Broadcast a read-only variable to the cluster, returning a
L{Broadcast<pyspark.broadcast.Broadcast>}
- object for reading it in distributed functions. The variable will be
- sent to each cluster only once.
+ object for reading it in distributed functions. The variable will
+ be sent to each cluster only once.
"""
- pickleSer = PickleSerializer()
- pickled = pickleSer.dumps(value)
- jbroadcast = self._jsc.broadcast(bytearray(pickled))
- return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars)
+ ser = CompressedSerializer(PickleSerializer())
+ # passing a large object through Py4J is very slow and needs a lot of memory
+ tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
+ ser.dump_stream([value], tempFile)
+ tempFile.close()
+ jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name)
+ return Broadcast(jbroadcast.id(), None, jbroadcast,
+ self._pickled_broadcast_vars, tempFile.name)
def accumulator(self, value, accum_param=None):
"""
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 3934bdda0a..240381e5ba 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -36,7 +36,7 @@ from math import sqrt, log
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
- PickleSerializer, pack_long
+ PickleSerializer, pack_long, CompressedSerializer
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -1810,7 +1810,8 @@ class PipelinedRDD(RDD):
self._jrdd_deserializer = NoOpSerializer()
command = (self.func, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
- pickled_command = CloudPickleSerializer().dumps(command)
+ ser = CompressedSerializer(CloudPickleSerializer())
+ pickled_command = ser.dumps(command)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
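
CompressedSerializer composes with any FramedSerializer, so the same wrapper used for broadcast values also shrinks the pickled closure here. A sketch of the round trip, assuming pyspark is importable (CloudPickleSerializer inherits loads() from PickleSerializer, so plain unpickling restores the dumped function):

    from pyspark.serializers import CloudPickleSerializer, CompressedSerializer

    ser = CompressedSerializer(CloudPickleSerializer())
    payload = ser.dumps((lambda x: x + 1, None, None))  # shaped like `command`
    func, _, _ = ser.loads(payload)
    assert func(41) == 42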
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index df90cafb24..74870c0edc 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -67,6 +67,7 @@ import struct
import sys
import types
import collections
+import zlib
from pyspark import cloudpickle
@@ -403,6 +404,22 @@ class AutoSerializer(FramedSerializer):
raise ValueError("invalid serialization type: %s" % _type)
+class CompressedSerializer(FramedSerializer):
+ """
+ Compress the serialized data with zlib (level 1: fast, modest ratio)
+ """
+
+ def __init__(self, serializer):
+ FramedSerializer.__init__(self)
+ self.serializer = serializer
+
+ def dumps(self, obj):
+ return zlib.compress(self.serializer.dumps(obj), 1)
+
+ def loads(self, obj):
+ return self.serializer.loads(zlib.decompress(obj))
+
+
class UTF8Deserializer(Serializer):
"""
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 22b51110ed..f1fece998c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -323,6 +323,13 @@ class TestRDDFunctions(PySparkTestCase):
theDoes = self.sc.parallelize([jon, jane])
self.assertEquals([jon, jane], theDoes.collect())
+ def test_large_broadcast(self):
+ N = 100000
+ data = [[float(i) for i in range(300)] for _ in range(N)]
+ bdata = self.sc.broadcast(data) # 270MB
+ m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
+ self.assertEquals(N, m)
+
class TestIO(PySparkTestCase):
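
The "# 270MB" estimate checks out: protocol-2 pickle spends 9 bytes per float inside a list (a BINFLOAT opcode plus an 8-byte IEEE-754 double, and floats are not memoized), so 100000 rows of 300 floats is 270,000,000 bytes of pickle data before compression:

    import cPickle

    # marginal cost of one more float in a pickled list
    per_float = len(cPickle.dumps([0.0] * 101, 2)) - len(cPickle.dumps([0.0] * 100, 2))
    print per_float                 # 9
    print 100000 * 300 * per_float  # 270000000 bytes ~= 270 MB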
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 2770f63059..77a9c4a0e0 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,7 +30,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
- write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
+ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
+ CompressedSerializer
pickleSer = PickleSerializer()
@@ -65,12 +66,13 @@ def main(infile, outfile):
# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
+ ser = CompressedSerializer(pickleSer)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
- value = pickleSer._read_with_length(infile)
+ value = ser._read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, value)
- command = pickleSer._read_with_length(infile)
+ command = ser._read_with_length(infile)
(func, deserializer, serializer) = command
init_time = time.time()
iterator = deserializer.load_stream(infile)
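
For reference, ser._read_with_length on the worker reads a 4-byte big-endian length, then that many bytes, and loads() decompresses and unpickles them. A standalone equivalent (read_framed is a hypothetical name, not pyspark API):

    import struct
    import zlib
    import cPickle

    def read_framed(stream):
        length = struct.unpack('!i', stream.read(4))[0]  # what read_int does
        return cPickle.loads(zlib.decompress(stream.read(length)))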