aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/serializers.py
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2014-11-18 16:17:51 -0800
committerJosh Rosen <joshrosen@databricks.com>2014-11-18 16:17:51 -0800
commit4a377aff2d36b64a65b54192a987aba44b8f78e0 (patch)
tree2c0e4cfd8c8f7ae21eb3d3048df7232f09c47304 /python/pyspark/serializers.py
parentd2e29516f2064f93f3a9070c91fc7460706e0b0a (diff)
downloadspark-4a377aff2d36b64a65b54192a987aba44b8f78e0.tar.gz
spark-4a377aff2d36b64a65b54192a987aba44b8f78e0.tar.bz2
spark-4a377aff2d36b64a65b54192a987aba44b8f78e0.zip
[SPARK-3721] [PySpark] broadcast objects larger than 2G
This patch will bring support for broadcasting objects larger than 2G. pickle, zlib, FrameSerializer and Array[Byte] all can not support objects larger than 2G, so this patch introduce LargeObjectSerializer to serialize broadcast objects, the object will be serialized and compressed into small chunks, it also change the type of Broadcast[Array[Byte]]] into Broadcast[Array[Array[Byte]]]]. Testing for support broadcast objects larger than 2G is slow and memory hungry, so this is tested manually, could be added into SparkPerf. Author: Davies Liu <davies@databricks.com> Author: Davies Liu <davies.liu@gmail.com> Closes #2659 from davies/huge and squashes the following commits: 7b57a14 [Davies Liu] add more tests for broadcast 28acff9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge a2f6a02 [Davies Liu] bug fix 4820613 [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge 5875c73 [Davies Liu] address comments 10a349b [Davies Liu] address comments 0c33016 [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge 6182c8f [Davies Liu] Merge branch 'master' into huge d94b68f [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge 2514848 [Davies Liu] address comments fda395b [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge 1c2d928 [Davies Liu] fix scala style 091b107 [Davies Liu] broadcast objects larger than 2G
Diffstat (limited to 'python/pyspark/serializers.py')
-rw-r--r--python/pyspark/serializers.py185
1 files changed, 178 insertions, 7 deletions
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index d597cbf94e..760a509f0e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -133,6 +133,8 @@ class FramedSerializer(Serializer):
def _write_with_length(self, obj, stream):
serialized = self.dumps(obj)
+ if len(serialized) > (1 << 31):
+ raise ValueError("can not serialize object larger than 2G")
write_int(len(serialized), stream)
if self._only_write_strings:
stream.write(str(serialized))
@@ -446,20 +448,184 @@ class AutoSerializer(FramedSerializer):
raise ValueError("invalid sevialization type: %s" % _type)
-class CompressedSerializer(FramedSerializer):
+class SizeLimitedStream(object):
"""
- Compress the serialized data
+ Read at most `limit` bytes from underlying stream
+
+ >>> from StringIO import StringIO
+ >>> io = StringIO()
+ >>> io.write("Hello world")
+ >>> io.seek(0)
+ >>> lio = SizeLimitedStream(io, 5)
+ >>> lio.read()
+ 'Hello'
+ """
+ def __init__(self, stream, limit):
+ self.stream = stream
+ self.limit = limit
+
+ def read(self, n=0):
+ if n > self.limit or n == 0:
+ n = self.limit
+ buf = self.stream.read(n)
+ self.limit -= len(buf)
+ return buf
+
+
+class CompressedStream(object):
+ """
+ Compress the data using zlib
+
+ >>> from StringIO import StringIO
+ >>> io = StringIO()
+ >>> wio = CompressedStream(io, 'w')
+ >>> wio.write("Hello world")
+ >>> wio.flush()
+ >>> io.seek(0)
+ >>> rio = CompressedStream(io, 'r')
+ >>> rio.read()
+ 'Hello world'
+ >>> rio.read()
+ ''
+ """
+ MAX_BATCH = 1 << 20 # 1MB
+
+ def __init__(self, stream, mode='w', level=1):
+ self.stream = stream
+ self.mode = mode
+ if mode == 'w':
+ self.compresser = zlib.compressobj(level)
+ elif mode == 'r':
+ self.decompresser = zlib.decompressobj()
+ self.buf = ''
+ else:
+ raise ValueError("can only support mode 'w' or 'r' ")
+
+ def write(self, buf):
+ assert self.mode == 'w', "It's not opened for write"
+ if len(buf) > self.MAX_BATCH:
+ # zlib can not compress string larger than 2G
+ batches = len(buf) / self.MAX_BATCH + 1 # last one may be empty
+ for i in xrange(batches):
+ self.write(buf[i * self.MAX_BATCH:(i + 1) * self.MAX_BATCH])
+ else:
+ compressed = self.compresser.compress(buf)
+ self.stream.write(compressed)
+
+ def flush(self, mode=zlib.Z_FULL_FLUSH):
+ if self.mode == 'w':
+ d = self.compresser.flush(mode)
+ self.stream.write(d)
+ self.stream.flush()
+
+ def close(self):
+ if self.mode == 'w':
+ self.flush(zlib.Z_FINISH)
+ self.stream.close()
+
+ def read(self, size=0):
+ assert self.mode == 'r', "It's not opened for read"
+ if not size:
+ data = self.stream.read()
+ result = self.decompresser.decompress(data)
+ last = self.decompresser.flush()
+ return self.buf + result + last
+
+ # fast path for small read()
+ if size <= len(self.buf):
+ result = self.buf[:size]
+ self.buf = self.buf[size:]
+ return result
+
+ result = [self.buf]
+ size -= len(self.buf)
+ self.buf = ''
+ while size:
+ need = min(size, self.MAX_BATCH)
+ input = self.stream.read(need)
+ if input:
+ buf = self.decompresser.decompress(input)
+ else:
+ buf = self.decompresser.flush()
+
+ if len(buf) >= size:
+ self.buf = buf[size:]
+ result.append(buf[:size])
+ return ''.join(result)
+
+ size -= len(buf)
+ result.append(buf)
+ if not input:
+ return ''.join(result)
+
+ def readline(self):
+ """
+ This is needed for pickle, but not used in protocol 2
+ """
+ line = []
+ b = self.read(1)
+ while b and b != '\n':
+ line.append(b)
+ b = self.read(1)
+ line.append(b)
+ return ''.join(line)
+
+
+class LargeObjectSerializer(Serializer):
+ """
+ Serialize large object which could be larger than 2G
+
+ It uses cPickle to serialize the objects
"""
+ def dump_stream(self, iterator, stream):
+ stream = CompressedStream(stream, 'w')
+ for value in iterator:
+ if isinstance(value, basestring):
+ if isinstance(value, unicode):
+ stream.write('U')
+ value = value.encode("utf-8")
+ else:
+ stream.write('S')
+ write_long(len(value), stream)
+ stream.write(value)
+ else:
+ stream.write('P')
+ cPickle.dump(value, stream, 2)
+ stream.flush()
+ def load_stream(self, stream):
+ stream = CompressedStream(stream, 'r')
+ while True:
+ type = stream.read(1)
+ if not type:
+ return
+ if type in ('S', 'U'):
+ length = read_long(stream)
+ value = stream.read(length)
+ if type == 'U':
+ value = value.decode('utf-8')
+ yield value
+ elif type == 'P':
+ yield cPickle.load(stream)
+ else:
+ raise ValueError("unknown type: %s" % type)
+
+
+class CompressedSerializer(Serializer):
+ """
+ Compress the serialized data
+ """
def __init__(self, serializer):
- FramedSerializer.__init__(self)
self.serializer = serializer
- def dumps(self, obj):
- return zlib.compress(self.serializer.dumps(obj), 1)
+ def load_stream(self, stream):
+ stream = CompressedStream(stream, "r")
+ return self.serializer.load_stream(stream)
- def loads(self, obj):
- return self.serializer.loads(zlib.decompress(obj))
+ def dump_stream(self, iterator, stream):
+ stream = CompressedStream(stream, "w")
+ self.serializer.dump_stream(iterator, stream)
+ stream.flush()
class UTF8Deserializer(Serializer):
@@ -517,3 +683,8 @@ def write_int(value, stream):
def write_with_length(obj, stream):
write_int(len(obj), stream)
stream.write(obj)
+
+
+if __name__ == '__main__':
+ import doctest
+ doctest.testmod()