Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/broadcast.py    95
-rw-r--r--  python/pyspark/context.py      12
-rw-r--r--  python/pyspark/serializers.py 178
-rw-r--r--  python/pyspark/tests.py        18
-rw-r--r--  python/pyspark/worker.py       10
5 files changed, 80 insertions, 233 deletions
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 01cac3c72c..6b8a8b256a 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -15,21 +15,10 @@
# limitations under the License.
#
-"""
->>> from pyspark.context import SparkContext
->>> sc = SparkContext('local', 'test')
->>> b = sc.broadcast([1, 2, 3, 4, 5])
->>> b.value
-[1, 2, 3, 4, 5]
->>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
-[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
->>> b.unpersist()
-
->>> large_broadcast = sc.broadcast(list(range(10000)))
-"""
import os
-
-from pyspark.serializers import LargeObjectSerializer
+import cPickle
+import gc
+from tempfile import NamedTemporaryFile
__all__ = ['Broadcast']
@@ -49,44 +38,88 @@ def _from_id(bid):
class Broadcast(object):
"""
- A broadcast variable created with
- L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}.
+ A broadcast variable created with L{SparkContext.broadcast()}.
Access its value through C{.value}.
+
+ Examples:
+
+ >>> from pyspark.context import SparkContext
+ >>> sc = SparkContext('local', 'test')
+ >>> b = sc.broadcast([1, 2, 3, 4, 5])
+ >>> b.value
+ [1, 2, 3, 4, 5]
+ >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+ [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+ >>> b.unpersist()
+
+ >>> large_broadcast = sc.broadcast(range(10000))
"""
- def __init__(self, bid, value, java_broadcast=None,
- pickle_registry=None, path=None):
+ def __init__(self, sc=None, value=None, pickle_registry=None, path=None):
"""
- Should not be called directly by users -- use
- L{SparkContext.broadcast()<pyspark.context.SparkContext.broadcast>}
+ Should not be called directly by users -- use L{SparkContext.broadcast()}
instead.
"""
- self.bid = bid
- if path is None:
- self._value = value
- self._jbroadcast = java_broadcast
- self._pickle_registry = pickle_registry
- self.path = path
+ if sc is not None:
+ f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
+ self._path = self.dump(value, f)
+ self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path)
+ self._pickle_registry = pickle_registry
+ else:
+ self._jbroadcast = None
+ self._path = path
+
+ def dump(self, value, f):
+ if isinstance(value, basestring):
+ if isinstance(value, unicode):
+ f.write('U')
+ value = value.encode('utf8')
+ else:
+ f.write('S')
+ f.write(value)
+ else:
+ f.write('P')
+ cPickle.dump(value, f, 2)
+ f.close()
+ return f.name
+
+ def load(self, path):
+ with open(path, 'rb', 1 << 20) as f:
+ flag = f.read(1)
+ data = f.read()
+ if flag == 'P':
+ # cPickle.loads() may create lots of objects, disable GC
+                # temporarily for better performance
+ gc.disable()
+ try:
+ return cPickle.loads(data)
+ finally:
+ gc.enable()
+ else:
+ return data.decode('utf8') if flag == 'U' else data
@property
def value(self):
""" Return the broadcasted value
"""
- if not hasattr(self, "_value") and self.path is not None:
- ser = LargeObjectSerializer()
- self._value = ser.load_stream(open(self.path)).next()
+ if not hasattr(self, "_value") and self._path is not None:
+ self._value = self.load(self._path)
return self._value
def unpersist(self, blocking=False):
"""
Delete cached copies of this broadcast on the executors.
"""
+ if self._jbroadcast is None:
+ raise Exception("Broadcast can only be unpersisted in driver")
self._jbroadcast.unpersist(blocking)
- os.unlink(self.path)
+ os.unlink(self._path)
def __reduce__(self):
+ if self._jbroadcast is None:
+ raise Exception("Broadcast can only be serialized in driver")
self._pickle_registry.add(self)
- return (_from_id, (self.bid, ))
+ return _from_id, (self._jbroadcast.id(),)
if __name__ == "__main__":
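
With this change the broadcast value round-trips through a single temp file tagged with a one-byte flag: 'U' for unicode (stored as UTF-8), 'S' for plain bytes, 'P' for anything else pickled with protocol 2. Below is a minimal standalone sketch of that format, kept in the patch's Python 2 idiom; dump_value, load_value and roundtrip are illustrative helpers, not part of PySpark.

    import os
    import cPickle
    from tempfile import NamedTemporaryFile

    def dump_value(value, f):
        # one-byte flag, then the payload: 'U' = utf8-encoded unicode,
        # 'S' = raw str, 'P' = pickle protocol 2
        if isinstance(value, basestring):
            if isinstance(value, unicode):
                f.write('U')
                value = value.encode('utf8')
            else:
                f.write('S')
            f.write(value)
        else:
            f.write('P')
            cPickle.dump(value, f, 2)
        f.close()
        return f.name

    def load_value(path):
        with open(path, 'rb') as f:
            flag = f.read(1)
            data = f.read()
        if flag == 'P':
            return cPickle.loads(data)
        return data.decode('utf8') if flag == 'U' else data

    def roundtrip(value):
        path = dump_value(value, NamedTemporaryFile(delete=False))
        try:
            return load_value(path)
        finally:
            os.unlink(path)

    assert roundtrip('abc') == 'abc'
    assert roundtrip(u'123') == u'123'
    assert roundtrip(range(5)) == range(5)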
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ec67ec8d0f..ed7351d60c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -29,7 +29,7 @@ from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
- PairDeserializer, AutoBatchedSerializer, NoOpSerializer, LargeObjectSerializer
+ PairDeserializer, AutoBatchedSerializer, NoOpSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
@@ -624,15 +624,7 @@ class SparkContext(object):
object for reading it in distributed functions. The variable will
be sent to each cluster only once.
"""
- ser = LargeObjectSerializer()
-
- # pass large object by py4j is very slow and need much memory
- tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
- ser.dump_stream([value], tempFile)
- tempFile.close()
- jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name)
- return Broadcast(jbroadcast.id(), None, jbroadcast,
- self._pickled_broadcast_vars, tempFile.name)
+ return Broadcast(self, value, self._pickled_broadcast_vars)
def accumulator(self, value, accum_param=None):
"""
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 760a509f0e..33aa55f7f1 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -448,184 +448,20 @@ class AutoSerializer(FramedSerializer):
raise ValueError("invalid sevialization type: %s" % _type)
-class SizeLimitedStream(object):
- """
- Read at most `limit` bytes from underlying stream
-
- >>> from StringIO import StringIO
- >>> io = StringIO()
- >>> io.write("Hello world")
- >>> io.seek(0)
- >>> lio = SizeLimitedStream(io, 5)
- >>> lio.read()
- 'Hello'
- """
- def __init__(self, stream, limit):
- self.stream = stream
- self.limit = limit
-
- def read(self, n=0):
- if n > self.limit or n == 0:
- n = self.limit
- buf = self.stream.read(n)
- self.limit -= len(buf)
- return buf
-
-
-class CompressedStream(object):
- """
- Compress the data using zlib
-
- >>> from StringIO import StringIO
- >>> io = StringIO()
- >>> wio = CompressedStream(io, 'w')
- >>> wio.write("Hello world")
- >>> wio.flush()
- >>> io.seek(0)
- >>> rio = CompressedStream(io, 'r')
- >>> rio.read()
- 'Hello world'
- >>> rio.read()
- ''
- """
- MAX_BATCH = 1 << 20 # 1MB
-
- def __init__(self, stream, mode='w', level=1):
- self.stream = stream
- self.mode = mode
- if mode == 'w':
- self.compresser = zlib.compressobj(level)
- elif mode == 'r':
- self.decompresser = zlib.decompressobj()
- self.buf = ''
- else:
- raise ValueError("can only support mode 'w' or 'r' ")
-
- def write(self, buf):
- assert self.mode == 'w', "It's not opened for write"
- if len(buf) > self.MAX_BATCH:
- # zlib can not compress string larger than 2G
- batches = len(buf) / self.MAX_BATCH + 1 # last one may be empty
- for i in xrange(batches):
- self.write(buf[i * self.MAX_BATCH:(i + 1) * self.MAX_BATCH])
- else:
- compressed = self.compresser.compress(buf)
- self.stream.write(compressed)
-
- def flush(self, mode=zlib.Z_FULL_FLUSH):
- if self.mode == 'w':
- d = self.compresser.flush(mode)
- self.stream.write(d)
- self.stream.flush()
-
- def close(self):
- if self.mode == 'w':
- self.flush(zlib.Z_FINISH)
- self.stream.close()
-
- def read(self, size=0):
- assert self.mode == 'r', "It's not opened for read"
- if not size:
- data = self.stream.read()
- result = self.decompresser.decompress(data)
- last = self.decompresser.flush()
- return self.buf + result + last
-
- # fast path for small read()
- if size <= len(self.buf):
- result = self.buf[:size]
- self.buf = self.buf[size:]
- return result
-
- result = [self.buf]
- size -= len(self.buf)
- self.buf = ''
- while size:
- need = min(size, self.MAX_BATCH)
- input = self.stream.read(need)
- if input:
- buf = self.decompresser.decompress(input)
- else:
- buf = self.decompresser.flush()
-
- if len(buf) >= size:
- self.buf = buf[size:]
- result.append(buf[:size])
- return ''.join(result)
-
- size -= len(buf)
- result.append(buf)
- if not input:
- return ''.join(result)
-
- def readline(self):
- """
- This is needed for pickle, but not used in protocol 2
- """
- line = []
- b = self.read(1)
- while b and b != '\n':
- line.append(b)
- b = self.read(1)
- line.append(b)
- return ''.join(line)
-
-
-class LargeObjectSerializer(Serializer):
- """
- Serialize large object which could be larger than 2G
-
- It uses cPickle to serialize the objects
- """
- def dump_stream(self, iterator, stream):
- stream = CompressedStream(stream, 'w')
- for value in iterator:
- if isinstance(value, basestring):
- if isinstance(value, unicode):
- stream.write('U')
- value = value.encode("utf-8")
- else:
- stream.write('S')
- write_long(len(value), stream)
- stream.write(value)
- else:
- stream.write('P')
- cPickle.dump(value, stream, 2)
- stream.flush()
-
- def load_stream(self, stream):
- stream = CompressedStream(stream, 'r')
- while True:
- type = stream.read(1)
- if not type:
- return
- if type in ('S', 'U'):
- length = read_long(stream)
- value = stream.read(length)
- if type == 'U':
- value = value.decode('utf-8')
- yield value
- elif type == 'P':
- yield cPickle.load(stream)
- else:
- raise ValueError("unknown type: %s" % type)
-
-
-class CompressedSerializer(Serializer):
+class CompressedSerializer(FramedSerializer):
"""
Compress the serialized data
"""
def __init__(self, serializer):
+ FramedSerializer.__init__(self)
+ assert isinstance(serializer, FramedSerializer), "serializer must be a FramedSerializer"
self.serializer = serializer
- def load_stream(self, stream):
- stream = CompressedStream(stream, "r")
- return self.serializer.load_stream(stream)
+ def dumps(self, obj):
+ return zlib.compress(self.serializer.dumps(obj), 1)
- def dump_stream(self, iterator, stream):
- stream = CompressedStream(stream, "w")
- self.serializer.dump_stream(iterator, stream)
- stream.flush()
+ def loads(self, obj):
+ return self.serializer.loads(zlib.decompress(obj))
class UTF8Deserializer(Serializer):
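
The new CompressedSerializer compresses each serialized frame with zlib at level 1 instead of wrapping the whole stream in a CompressedStream. Here is a minimal sketch of that per-frame scheme; PlainPickle and PerFrameCompressed are hypothetical stand-ins, not the real pyspark.serializers classes.

    import zlib
    import cPickle

    class PlainPickle(object):
        # stand-in for a framed serializer's dumps()/loads() pair
        def dumps(self, obj):
            return cPickle.dumps(obj, 2)

        def loads(self, data):
            return cPickle.loads(data)

    class PerFrameCompressed(object):
        # compress each frame independently; level 1 trades ratio for speed
        def __init__(self, serializer):
            self.serializer = serializer

        def dumps(self, obj):
            return zlib.compress(self.serializer.dumps(obj), 1)

        def loads(self, data):
            return self.serializer.loads(zlib.decompress(data))

    ser = PerFrameCompressed(PlainPickle())
    assert ser.loads(ser.dumps(range(10))) == range(10)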
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 29bcd38908..32645778c2 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -48,7 +48,7 @@ from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
- CloudPickleSerializer, SizeLimitedStream, CompressedSerializer, LargeObjectSerializer
+ CloudPickleSerializer, CompressedSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
UserDefinedType, DoubleType
@@ -237,26 +237,16 @@ class SerializationTestCase(unittest.TestCase):
self.assertTrue("exit" in foo.func_code.co_names)
ser.dumps(foo)
- def _test_serializer(self, ser):
+ def test_compressed_serializer(self):
+ ser = CompressedSerializer(PickleSerializer())
from StringIO import StringIO
io = StringIO()
ser.dump_stream(["abc", u"123", range(5)], io)
io.seek(0)
self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
- size = io.tell()
ser.dump_stream(range(1000), io)
io.seek(0)
- first = SizeLimitedStream(io, size)
- self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(first)))
- self.assertEqual(range(1000), list(ser.load_stream(io)))
-
- def test_compressed_serializer(self):
- ser = CompressedSerializer(PickleSerializer())
- self._test_serializer(ser)
-
- def test_large_object_serializer(self):
- ser = LargeObjectSerializer()
- self._test_serializer(ser)
+ self.assertEqual(["abc", u"123", range(5)] + range(1000), list(ser.load_stream(io)))
class PySparkTestCase(unittest.TestCase):
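
The rewritten test leans on how a framed stream behaves: load_stream() keeps yielding frames until EOF, so two dump_stream() calls into the same buffer come back as one concatenated list. The following self-contained sketch shows that behaviour with simplified length-prefixed framing in place of the real pyspark.serializers machinery.

    import struct
    import zlib
    import cPickle
    from StringIO import StringIO

    def dump_stream(objs, stream):
        # each object becomes one length-prefixed, zlib-compressed frame
        for obj in objs:
            frame = zlib.compress(cPickle.dumps(obj, 2), 1)
            stream.write(struct.pack("!i", len(frame)))
            stream.write(frame)

    def load_stream(stream):
        # yield frames until the stream runs out; there is no marker
        # between separate dump_stream() calls
        while True:
            header = stream.read(4)
            if len(header) < 4:
                return
            length, = struct.unpack("!i", header)
            yield cPickle.loads(zlib.decompress(stream.read(length)))

    io = StringIO()
    dump_stream(["abc", u"123", range(5)], io)
    dump_stream(range(1000), io)
    io.seek(0)
    assert list(load_stream(io)) == ["abc", u"123", range(5)] + range(1000)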
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e1552a0b0b..7e5343c973 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,8 +30,7 @@ from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
- write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
- SizeLimitedStream, LargeObjectSerializer
+ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
from pyspark import shuffle
pickleSer = PickleSerializer()
@@ -78,14 +77,11 @@ def main(infile, outfile):
# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
- bser = LargeObjectSerializer()
for _ in range(num_broadcast_variables):
bid = read_long(infile)
if bid >= 0:
- size = read_long(infile)
- s = SizeLimitedStream(infile, size)
- value = list((bser.load_stream(s)))[0] # read out all the bytes
- _broadcastRegistry[bid] = Broadcast(bid, value)
+ path = utf8_deserializer.loads(infile)
+ _broadcastRegistry[bid] = Broadcast(path=path)
else:
bid = - bid - 1
_broadcastRegistry.pop(bid)
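
On the worker side only the broadcast's file path crosses the socket; the value is read from disk lazily, the first time .value is accessed. A sketch of that pattern follows; LazyBroadcast is a hypothetical stand-in for pyspark.broadcast.Broadcast(path=...), and for brevity it handles only the pickled ('P'-flag) case.

    import cPickle
    from tempfile import NamedTemporaryFile

    class LazyBroadcast(object):
        def __init__(self, path):
            # only remember where the data lives; do not touch the file yet
            self._path = path

        @property
        def value(self):
            if not hasattr(self, "_value"):
                with open(self._path, 'rb') as f:
                    self._value = cPickle.load(f)
            return self._value

    # driver/worker handshake in miniature: write the value once, ship the path
    f = NamedTemporaryFile(delete=False)
    cPickle.dump({'answer': 42}, f, 2)
    f.close()
    b = LazyBroadcast(f.name)
    assert b.value == {'answer': 42}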