diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/broadcast.py | 4 | ||||
-rw-r--r-- | python/pyspark/context.py | 5 | ||||
-rw-r--r-- | python/pyspark/serializers.py | 185 | ||||
-rw-r--r-- | python/pyspark/tests.py | 52 | ||||
-rw-r--r-- | python/pyspark/worker.py | 8 | ||||
-rwxr-xr-x | python/run-tests | 2 |
6 files changed, 239 insertions, 17 deletions
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index f124dc6c07..01cac3c72c 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -29,7 +29,7 @@ """ import os -from pyspark.serializers import CompressedSerializer, PickleSerializer +from pyspark.serializers import LargeObjectSerializer __all__ = ['Broadcast'] @@ -73,7 +73,7 @@ class Broadcast(object): """ Return the broadcasted value """ if not hasattr(self, "_value") and self.path is not None: - ser = CompressedSerializer(PickleSerializer()) + ser = LargeObjectSerializer() self._value = ser.load_stream(open(self.path)).next() return self._value diff --git a/python/pyspark/context.py b/python/pyspark/context.py index b6c991453d..ec67ec8d0f 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.conf import SparkConf from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, CompressedSerializer, AutoBatchedSerializer, NoOpSerializer + PairDeserializer, AutoBatchedSerializer, NoOpSerializer, LargeObjectSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -624,7 +624,8 @@ class SparkContext(object): object for reading it in distributed functions. The variable will be sent to each cluster only once. """ - ser = CompressedSerializer(PickleSerializer()) + ser = LargeObjectSerializer() + # pass large object by py4j is very slow and need much memory tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) ser.dump_stream([value], tempFile) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d597cbf94e..760a509f0e 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -133,6 +133,8 @@ class FramedSerializer(Serializer): def _write_with_length(self, obj, stream): serialized = self.dumps(obj) + if len(serialized) > (1 << 31): + raise ValueError("can not serialize object larger than 2G") write_int(len(serialized), stream) if self._only_write_strings: stream.write(str(serialized)) @@ -446,20 +448,184 @@ class AutoSerializer(FramedSerializer): raise ValueError("invalid sevialization type: %s" % _type) -class CompressedSerializer(FramedSerializer): +class SizeLimitedStream(object): """ - Compress the serialized data + Read at most `limit` bytes from underlying stream + + >>> from StringIO import StringIO + >>> io = StringIO() + >>> io.write("Hello world") + >>> io.seek(0) + >>> lio = SizeLimitedStream(io, 5) + >>> lio.read() + 'Hello' + """ + def __init__(self, stream, limit): + self.stream = stream + self.limit = limit + + def read(self, n=0): + if n > self.limit or n == 0: + n = self.limit + buf = self.stream.read(n) + self.limit -= len(buf) + return buf + + +class CompressedStream(object): + """ + Compress the data using zlib + + >>> from StringIO import StringIO + >>> io = StringIO() + >>> wio = CompressedStream(io, 'w') + >>> wio.write("Hello world") + >>> wio.flush() + >>> io.seek(0) + >>> rio = CompressedStream(io, 'r') + >>> rio.read() + 'Hello world' + >>> rio.read() + '' + """ + MAX_BATCH = 1 << 20 # 1MB + + def __init__(self, stream, mode='w', level=1): + self.stream = stream + self.mode = mode + if mode == 'w': + self.compresser = zlib.compressobj(level) + elif mode == 'r': + self.decompresser = zlib.decompressobj() + self.buf = '' + else: + raise ValueError("can only support mode 'w' or 'r' ") + + def write(self, buf): + assert self.mode == 'w', "It's not opened for write" + if len(buf) > self.MAX_BATCH: + # zlib can not compress string larger than 2G + batches = len(buf) / self.MAX_BATCH + 1 # last one may be empty + for i in xrange(batches): + self.write(buf[i * self.MAX_BATCH:(i + 1) * self.MAX_BATCH]) + else: + compressed = self.compresser.compress(buf) + self.stream.write(compressed) + + def flush(self, mode=zlib.Z_FULL_FLUSH): + if self.mode == 'w': + d = self.compresser.flush(mode) + self.stream.write(d) + self.stream.flush() + + def close(self): + if self.mode == 'w': + self.flush(zlib.Z_FINISH) + self.stream.close() + + def read(self, size=0): + assert self.mode == 'r', "It's not opened for read" + if not size: + data = self.stream.read() + result = self.decompresser.decompress(data) + last = self.decompresser.flush() + return self.buf + result + last + + # fast path for small read() + if size <= len(self.buf): + result = self.buf[:size] + self.buf = self.buf[size:] + return result + + result = [self.buf] + size -= len(self.buf) + self.buf = '' + while size: + need = min(size, self.MAX_BATCH) + input = self.stream.read(need) + if input: + buf = self.decompresser.decompress(input) + else: + buf = self.decompresser.flush() + + if len(buf) >= size: + self.buf = buf[size:] + result.append(buf[:size]) + return ''.join(result) + + size -= len(buf) + result.append(buf) + if not input: + return ''.join(result) + + def readline(self): + """ + This is needed for pickle, but not used in protocol 2 + """ + line = [] + b = self.read(1) + while b and b != '\n': + line.append(b) + b = self.read(1) + line.append(b) + return ''.join(line) + + +class LargeObjectSerializer(Serializer): + """ + Serialize large object which could be larger than 2G + + It uses cPickle to serialize the objects """ + def dump_stream(self, iterator, stream): + stream = CompressedStream(stream, 'w') + for value in iterator: + if isinstance(value, basestring): + if isinstance(value, unicode): + stream.write('U') + value = value.encode("utf-8") + else: + stream.write('S') + write_long(len(value), stream) + stream.write(value) + else: + stream.write('P') + cPickle.dump(value, stream, 2) + stream.flush() + def load_stream(self, stream): + stream = CompressedStream(stream, 'r') + while True: + type = stream.read(1) + if not type: + return + if type in ('S', 'U'): + length = read_long(stream) + value = stream.read(length) + if type == 'U': + value = value.decode('utf-8') + yield value + elif type == 'P': + yield cPickle.load(stream) + else: + raise ValueError("unknown type: %s" % type) + + +class CompressedSerializer(Serializer): + """ + Compress the serialized data + """ def __init__(self, serializer): - FramedSerializer.__init__(self) self.serializer = serializer - def dumps(self, obj): - return zlib.compress(self.serializer.dumps(obj), 1) + def load_stream(self, stream): + stream = CompressedStream(stream, "r") + return self.serializer.load_stream(stream) - def loads(self, obj): - return self.serializer.loads(zlib.decompress(obj)) + def dump_stream(self, iterator, stream): + stream = CompressedStream(stream, "w") + self.serializer.dump_stream(iterator, stream) + stream.flush() class UTF8Deserializer(Serializer): @@ -517,3 +683,8 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) stream.write(obj) + + +if __name__ == '__main__': + import doctest + doctest.testmod() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 491e445a21..a01bd8d415 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -32,6 +32,7 @@ import time import zipfile import random import threading +import hashlib if sys.version_info[:2] <= (2, 6): try: @@ -47,7 +48,7 @@ from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer + CloudPickleSerializer, SizeLimitedStream, CompressedSerializer, LargeObjectSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ UserDefinedType, DoubleType @@ -236,6 +237,27 @@ class SerializationTestCase(unittest.TestCase): self.assertTrue("exit" in foo.func_code.co_names) ser.dumps(foo) + def _test_serializer(self, ser): + from StringIO import StringIO + io = StringIO() + ser.dump_stream(["abc", u"123", range(5)], io) + io.seek(0) + self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) + size = io.tell() + ser.dump_stream(range(1000), io) + io.seek(0) + first = SizeLimitedStream(io, size) + self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(first))) + self.assertEqual(range(1000), list(ser.load_stream(io))) + + def test_compressed_serializer(self): + ser = CompressedSerializer(PickleSerializer()) + self._test_serializer(ser) + + def test_large_object_serializer(self): + ser = LargeObjectSerializer() + self._test_serializer(ser) + class PySparkTestCase(unittest.TestCase): @@ -440,7 +462,7 @@ class RDDTests(ReusedPySparkTestCase): subset = data.takeSample(False, 10) self.assertEqual(len(subset), 10) - def testAggregateByKey(self): + def test_aggregate_by_key(self): data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) def seqOp(x, y): @@ -478,6 +500,32 @@ class RDDTests(ReusedPySparkTestCase): m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEquals(N, m) + def test_multiple_broadcasts(self): + N = 1 << 21 + b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM + r = range(1 << 15) + random.shuffle(r) + s = str(r) + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + + random.shuffle(r) + s = str(r) + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + def test_large_closure(self): N = 1000000 data = [float(i) for i in xrange(N)] diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 2bdccb5e93..e1552a0b0b 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,7 +31,7 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ - CompressedSerializer + SizeLimitedStream, LargeObjectSerializer from pyspark import shuffle pickleSer = PickleSerializer() @@ -78,11 +78,13 @@ def main(infile, outfile): # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) - ser = CompressedSerializer(pickleSer) + bser = LargeObjectSerializer() for _ in range(num_broadcast_variables): bid = read_long(infile) if bid >= 0: - value = ser._read_with_length(infile) + size = read_long(infile) + s = SizeLimitedStream(infile, size) + value = list((bser.load_stream(s)))[0] # read out all the bytes _broadcastRegistry[bid] = Broadcast(bid, value) else: bid = - bid - 1 diff --git a/python/run-tests b/python/run-tests index e66854b44d..9ee19ed6e6 100755 --- a/python/run-tests +++ b/python/run-tests @@ -56,7 +56,7 @@ function run_core_tests() { run_test "pyspark/conf.py" PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py" + run_test "pyspark/serializers.py" run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" } |