diff options
Diffstat (limited to 'python/pyspark/serializers.py')
-rw-r--r-- | python/pyspark/serializers.py | 185 |
1 files changed, 178 insertions, 7 deletions
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d597cbf94e..760a509f0e 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -133,6 +133,8 @@ class FramedSerializer(Serializer): def _write_with_length(self, obj, stream): serialized = self.dumps(obj) + if len(serialized) > (1 << 31): + raise ValueError("can not serialize object larger than 2G") write_int(len(serialized), stream) if self._only_write_strings: stream.write(str(serialized)) @@ -446,20 +448,184 @@ class AutoSerializer(FramedSerializer): raise ValueError("invalid sevialization type: %s" % _type) -class CompressedSerializer(FramedSerializer): +class SizeLimitedStream(object): """ - Compress the serialized data + Read at most `limit` bytes from underlying stream + + >>> from StringIO import StringIO + >>> io = StringIO() + >>> io.write("Hello world") + >>> io.seek(0) + >>> lio = SizeLimitedStream(io, 5) + >>> lio.read() + 'Hello' + """ + def __init__(self, stream, limit): + self.stream = stream + self.limit = limit + + def read(self, n=0): + if n > self.limit or n == 0: + n = self.limit + buf = self.stream.read(n) + self.limit -= len(buf) + return buf + + +class CompressedStream(object): + """ + Compress the data using zlib + + >>> from StringIO import StringIO + >>> io = StringIO() + >>> wio = CompressedStream(io, 'w') + >>> wio.write("Hello world") + >>> wio.flush() + >>> io.seek(0) + >>> rio = CompressedStream(io, 'r') + >>> rio.read() + 'Hello world' + >>> rio.read() + '' + """ + MAX_BATCH = 1 << 20 # 1MB + + def __init__(self, stream, mode='w', level=1): + self.stream = stream + self.mode = mode + if mode == 'w': + self.compresser = zlib.compressobj(level) + elif mode == 'r': + self.decompresser = zlib.decompressobj() + self.buf = '' + else: + raise ValueError("can only support mode 'w' or 'r' ") + + def write(self, buf): + assert self.mode == 'w', "It's not opened for write" + if len(buf) > self.MAX_BATCH: + # zlib can not compress string larger than 2G + batches = len(buf) / self.MAX_BATCH + 1 # last one may be empty + for i in xrange(batches): + self.write(buf[i * self.MAX_BATCH:(i + 1) * self.MAX_BATCH]) + else: + compressed = self.compresser.compress(buf) + self.stream.write(compressed) + + def flush(self, mode=zlib.Z_FULL_FLUSH): + if self.mode == 'w': + d = self.compresser.flush(mode) + self.stream.write(d) + self.stream.flush() + + def close(self): + if self.mode == 'w': + self.flush(zlib.Z_FINISH) + self.stream.close() + + def read(self, size=0): + assert self.mode == 'r', "It's not opened for read" + if not size: + data = self.stream.read() + result = self.decompresser.decompress(data) + last = self.decompresser.flush() + return self.buf + result + last + + # fast path for small read() + if size <= len(self.buf): + result = self.buf[:size] + self.buf = self.buf[size:] + return result + + result = [self.buf] + size -= len(self.buf) + self.buf = '' + while size: + need = min(size, self.MAX_BATCH) + input = self.stream.read(need) + if input: + buf = self.decompresser.decompress(input) + else: + buf = self.decompresser.flush() + + if len(buf) >= size: + self.buf = buf[size:] + result.append(buf[:size]) + return ''.join(result) + + size -= len(buf) + result.append(buf) + if not input: + return ''.join(result) + + def readline(self): + """ + This is needed for pickle, but not used in protocol 2 + """ + line = [] + b = self.read(1) + while b and b != '\n': + line.append(b) + b = self.read(1) + line.append(b) + return ''.join(line) + + +class LargeObjectSerializer(Serializer): + """ + Serialize large object which could be larger than 2G + + It uses cPickle to serialize the objects """ + def dump_stream(self, iterator, stream): + stream = CompressedStream(stream, 'w') + for value in iterator: + if isinstance(value, basestring): + if isinstance(value, unicode): + stream.write('U') + value = value.encode("utf-8") + else: + stream.write('S') + write_long(len(value), stream) + stream.write(value) + else: + stream.write('P') + cPickle.dump(value, stream, 2) + stream.flush() + def load_stream(self, stream): + stream = CompressedStream(stream, 'r') + while True: + type = stream.read(1) + if not type: + return + if type in ('S', 'U'): + length = read_long(stream) + value = stream.read(length) + if type == 'U': + value = value.decode('utf-8') + yield value + elif type == 'P': + yield cPickle.load(stream) + else: + raise ValueError("unknown type: %s" % type) + + +class CompressedSerializer(Serializer): + """ + Compress the serialized data + """ def __init__(self, serializer): - FramedSerializer.__init__(self) self.serializer = serializer - def dumps(self, obj): - return zlib.compress(self.serializer.dumps(obj), 1) + def load_stream(self, stream): + stream = CompressedStream(stream, "r") + return self.serializer.load_stream(stream) - def loads(self, obj): - return self.serializer.loads(zlib.decompress(obj)) + def dump_stream(self, iterator, stream): + stream = CompressedStream(stream, "w") + self.serializer.dump_stream(iterator, stream) + stream.flush() class UTF8Deserializer(Serializer): @@ -517,3 +683,8 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) stream.write(obj) + + +if __name__ == '__main__': + import doctest + doctest.testmod() |