diff options
Diffstat (limited to 'python/pyspark/serializers.py')
-rw-r--r-- | python/pyspark/serializers.py | 36 |
1 files changed, 36 insertions, 0 deletions
# Reconstructed from the collapsed diff of python/pyspark/serializers.py
# (index 44ac564283..2672da36c1).

# Hunk @@ -68,6 +68,7 @@: import added next to the module's existing
# sys / types / collections / zlib imports.
import itertools


# Hunk @@ -214,6 +215,41 @@: class added right after BatchedSerializer.
class AutoBatchedSerializer(BatchedSerializer):
    """
    Choose the size of batch automatically based on the size of object

    Objects are grouped into batches; every batch is dumped with the
    wrapped serializer and written to the stream as its length followed
    by its data. The number of objects per batch starts at 1 and is
    re-tuned after each write, using `bestSize` as the reference size.
    """

    def __init__(self, serializer, bestSize=1 << 20):
        """
        :param serializer: serializer whose `dumps` is applied to each
                           batch (a list of objects)
        :param bestSize: reference size, in bytes, of one serialized
                         batch (default 1 << 20)
        """
        # NOTE(review): the parent's last argument is fixed to -1,
        # presumably its batch size -- confirm against BatchedSerializer.
        BatchedSerializer.__init__(self, serializer, -1)
        self.bestSize = bestSize

    def dump_stream(self, iterator, stream):
        """
        Serialize the objects of `iterator` to `stream` in adaptive batches.

        Each batch is written as its serialized length (via `write_int`)
        followed by the serialized data. The batch length doubles while a
        serialized batch is smaller than `bestSize`, and is halved when a
        serialized batch is more than ten times `bestSize` (never going
        below one object per batch).
        """
        batch, best = 1, self.bestSize
        iterator = iter(iterator)
        while True:
            vs = list(itertools.islice(iterator, batch))
            if not vs:
                # iterator exhausted
                break

            data = self.serializer.dumps(vs)
            size = len(data)
            write_int(size, stream)
            stream.write(data)

            if size < best:
                batch *= 2
            elif size > best * 10 and batch > 1:
                # Floor division keeps `batch` an int; true division would
                # turn it into a float, which itertools.islice rejects.
                batch //= 2

    def __eq__(self, other):
        """
        Equal to any other AutoBatchedSerializer wrapping an equal
        serializer; `bestSize` is not part of the comparison.
        """
        return (isinstance(other, AutoBatchedSerializer) and
                other.serializer == self.serializer)

    def __str__(self):
        # NOTE(review): this reports the parent's name "BatchedSerializer";
        # kept unchanged in case callers rely on the text -- confirm and
        # consider renaming to "AutoBatchedSerializer".
        return "BatchedSerializer<%s>" % str(self.serializer)


# Trailing diff context: class CartesianDeserializer(FramedSerializer)
# follows in the original file (truncated in this view).