Diffstat (limited to 'python/pyspark/serializers.py')
 python/pyspark/serializers.py | 185 ++++++++++++++++++++++++++++++++++++++---
 1 file changed, 178 insertions(+), 7 deletions(-)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index d597cbf94e..760a509f0e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -133,6 +133,8 @@ class FramedSerializer(Serializer):
    def _write_with_length(self, obj, stream):
        serialized = self.dumps(obj)
+        if len(serialized) > (1 << 31):
+            raise ValueError("cannot serialize object larger than 2G")
        write_int(len(serialized), stream)
        if self._only_write_strings:
            stream.write(str(serialized))
@@ -446,20 +448,184 @@ class AutoSerializer(FramedSerializer):
raise ValueError("invalid sevialization type: %s" % _type)
-class CompressedSerializer(FramedSerializer):
+class SizeLimitedStream(object):
"""
- Compress the serialized data
+ Read at most `limit` bytes from underlying stream
+
+ >>> from StringIO import StringIO
+ >>> io = StringIO()
+ >>> io.write("Hello world")
+ >>> io.seek(0)
+ >>> lio = SizeLimitedStream(io, 5)
+ >>> lio.read()
+ 'Hello'
+ """
+ def __init__(self, stream, limit):
+ self.stream = stream
+ self.limit = limit
+
+ def read(self, n=0):
+ if n > self.limit or n == 0:
+ n = self.limit
+ buf = self.stream.read(n)
+ self.limit -= len(buf)
+ return buf
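+
+# Editor's sketch (not part of the original patch): partial reads are also
+# capped by the remaining limit, so a reader can never overrun it:
+#
+#   >>> io = StringIO("Hello world")
+#   >>> lio = SizeLimitedStream(io, 5)
+#   >>> lio.read(2), lio.read(2), lio.read(2)
+#   ('He', 'll', 'o')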
+
+
+class CompressedStream(object):
+    """
+    Compress the data using zlib
+
+    >>> from StringIO import StringIO
+    >>> io = StringIO()
+    >>> wio = CompressedStream(io, 'w')
+    >>> wio.write("Hello world")
+    >>> wio.flush()
+    >>> io.seek(0)
+    >>> rio = CompressedStream(io, 'r')
+    >>> rio.read()
+    'Hello world'
+    >>> rio.read()
+    ''
+    """
+    MAX_BATCH = 1 << 20  # 1MB
+
+    def __init__(self, stream, mode='w', level=1):
+        self.stream = stream
+        self.mode = mode
+        if mode == 'w':
+            self.compresser = zlib.compressobj(level)
+        elif mode == 'r':
+            self.decompresser = zlib.decompressobj()
+            self.buf = ''
+        else:
+            raise ValueError("mode must be 'w' or 'r'")
+
+    def write(self, buf):
+        assert self.mode == 'w', "stream is not opened for writing"
+        if len(buf) > self.MAX_BATCH:
+            # zlib cannot compress strings larger than 2G, so feed it
+            # slices of at most MAX_BATCH bytes
+            batches = len(buf) / self.MAX_BATCH + 1  # the last one may be empty
+            for i in xrange(batches):
+                self.write(buf[i * self.MAX_BATCH:(i + 1) * self.MAX_BATCH])
+        else:
+            compressed = self.compresser.compress(buf)
+            self.stream.write(compressed)
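+
+    # Editor's sketch (not part of the original patch): for a 2.5MB buffer,
+    # batches = 2621440 / (1 << 20) + 1 = 3, so write() recurses into two
+    # 1MB slices plus one 0.5MB slice, keeping every compress() call small.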
+
+    def flush(self, mode=zlib.Z_FULL_FLUSH):
+        if self.mode == 'w':
+            d = self.compresser.flush(mode)
+            self.stream.write(d)
+            self.stream.flush()
+
+    def close(self):
+        if self.mode == 'w':
+            self.flush(zlib.Z_FINISH)
+        self.stream.close()
+
+    def read(self, size=0):
+        assert self.mode == 'r', "stream is not opened for reading"
+        if not size:
+            data = self.stream.read()
+            result = self.decompresser.decompress(data)
+            last = self.decompresser.flush()
+            return self.buf + result + last
+
+        # fast path for small read()
+        if size <= len(self.buf):
+            result = self.buf[:size]
+            self.buf = self.buf[size:]
+            return result
+
+        result = [self.buf]
+        size -= len(self.buf)
+        self.buf = ''
+        while size:
+            need = min(size, self.MAX_BATCH)
+            data = self.stream.read(need)
+            if data:
+                buf = self.decompresser.decompress(data)
+            else:
+                buf = self.decompresser.flush()
+
+            if len(buf) >= size:
+                self.buf = buf[size:]
+                result.append(buf[:size])
+                return ''.join(result)
+
+            size -= len(buf)
+            result.append(buf)
+            if not data:
+                return ''.join(result)
+
+    def readline(self):
+        """
+        Read a single line, as required by cPickle.load() (pickle
+        protocol 2 does not actually use it)
+        """
+        line = []
+        b = self.read(1)
+        while b and b != '\n':
+            line.append(b)
+            b = self.read(1)
+        line.append(b)
+        return ''.join(line)
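+
+    # Editor's sketch (not part of the original patch): readline() stops
+    # right after the first '\n', as cPickle expects:
+    #
+    #   >>> io = StringIO()
+    #   >>> w = CompressedStream(io, 'w'); w.write("a\nb"); w.flush()
+    #   >>> io.seek(0)
+    #   >>> r = CompressedStream(io, 'r')
+    #   >>> r.readline(), r.readline()
+    #   ('a\n', 'b')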
+
+
+class LargeObjectSerializer(Serializer):
+    """
+    Serialize large objects, which could be larger than 2G
+
+    It uses cPickle to serialize the objects
    """
+    def dump_stream(self, iterator, stream):
+        stream = CompressedStream(stream, 'w')
+        for value in iterator:
+            if isinstance(value, basestring):
+                if isinstance(value, unicode):
+                    stream.write('U')
+                    value = value.encode("utf-8")
+                else:
+                    stream.write('S')
+                write_long(len(value), stream)
+                stream.write(value)
+            else:
+                stream.write('P')
+                cPickle.dump(value, stream, 2)
+        stream.flush()
+
+    def load_stream(self, stream):
+        stream = CompressedStream(stream, 'r')
+        while True:
+            type = stream.read(1)
+            if not type:
+                return
+            if type in ('S', 'U'):
+                length = read_long(stream)
+                value = stream.read(length)
+                if type == 'U':
+                    value = value.decode('utf-8')
+                yield value
+            elif type == 'P':
+                yield cPickle.load(stream)
+            else:
+                raise ValueError("unknown type: %s" % type)
+
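+# Editor's sketch (not part of the original patch): LargeObjectSerializer
+# round-trips plain strings, unicode and arbitrary pickled objects:
+#
+#   >>> from StringIO import StringIO
+#   >>> io = StringIO()
+#   >>> ser = LargeObjectSerializer()
+#   >>> ser.dump_stream(["abc", u"ab", (1, 2)], io)
+#   >>> io.seek(0)
+#   >>> list(ser.load_stream(io))
+#   ['abc', u'ab', (1, 2)]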
+
+class CompressedSerializer(Serializer):
+    """
+    Compress the serialized data
+    """
    def __init__(self, serializer):
-        FramedSerializer.__init__(self)
        self.serializer = serializer

-    def dumps(self, obj):
-        return zlib.compress(self.serializer.dumps(obj), 1)
+    def load_stream(self, stream):
+        stream = CompressedStream(stream, "r")
+        return self.serializer.load_stream(stream)

-    def loads(self, obj):
-        return self.serializer.loads(zlib.decompress(obj))
+    def dump_stream(self, iterator, stream):
+        stream = CompressedStream(stream, "w")
+        self.serializer.dump_stream(iterator, stream)
+        stream.flush()
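+
+    # Editor's sketch (not part of the original patch): wrap an existing
+    # serializer from this module, e.g. PickleSerializer, to compress the
+    # whole framed stream rather than each object individually:
+    #
+    #   >>> io = StringIO()
+    #   >>> ser = CompressedSerializer(PickleSerializer())
+    #   >>> ser.dump_stream([1, 2, 3], io)
+    #   >>> io.seek(0)
+    #   >>> list(ser.load_stream(io))
+    #   [1, 2, 3]
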
class UTF8Deserializer(Serializer):
@@ -517,3 +683,8 @@ def write_int(value, stream):
def write_with_length(obj, stream):
    write_int(len(obj), stream)
    stream.write(obj)
+
+
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()