diff options
Diffstat (limited to 'python/pyspark/serializers.py')
-rw-r--r-- | python/pyspark/serializers.py | 310 |
1 files changed, 243 insertions, 67 deletions
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fd02e1ee8f..4fb444443f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -15,8 +15,58 @@ # limitations under the License. # -import struct +""" +PySpark supports custom serializers for transferring data; this can improve +performance. + +By default, PySpark uses L{PickleSerializer} to serialize objects using Python's +C{cPickle} serializer, which can serialize nearly any Python object. +Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be +faster. + +The serializer is chosen when creating L{SparkContext}: + +>>> from pyspark.context import SparkContext +>>> from pyspark.serializers import MarshalSerializer +>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer()) +>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) +[0, 2, 4, 6, 8, 10, 12, 14, 16, 18] +>>> sc.stop() + +By default, PySpark serialize objects in batches; the batch size can be +controlled through SparkContext's C{batchSize} parameter +(the default size is 1024 objects): + +>>> sc = SparkContext('local', 'test', batchSize=2) +>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) + +Behind the scenes, this creates a JavaRDD with four partitions, each of +which contains two batches of two objects: + +>>> rdd.glom().collect() +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] +>>> rdd._jrdd.count() +8L +>>> sc.stop() + +A batch size of -1 uses an unlimited batch size, and a size of 1 disables +batching: + +>>> sc = SparkContext('local', 'test', batchSize=1) +>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) +>>> rdd.glom().collect() +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] +>>> rdd._jrdd.count() +16L +""" + import cPickle +from itertools import chain, izip, product +import marshal +import struct + + +__all__ = ["PickleSerializer", "MarshalSerializer"] class SpecialLengths(object): @@ -25,41 +75,206 @@ class SpecialLengths(object): TIMING_DATA = -3 -class Batch(object): +class Serializer(object): + + def dump_stream(self, iterator, stream): + """ + Serialize an iterator of objects to the output stream. + """ + raise NotImplementedError + + def load_stream(self, stream): + """ + Return an iterator of deserialized objects from the input stream. + """ + raise NotImplementedError + + + def _load_stream_without_unbatching(self, stream): + return self.load_stream(stream) + + # Note: our notion of "equality" is that output generated by + # equal serializers can be deserialized using the same serializer. + + # This default implementation handles the simple cases; + # subclasses should override __eq__ as appropriate. + + def __eq__(self, other): + return isinstance(other, self.__class__) + + def __ne__(self, other): + return not self.__eq__(other) + + +class FramedSerializer(Serializer): + """ + Serializer that writes objects as a stream of (length, data) pairs, + where C{length} is a 32-bit integer and data is C{length} bytes. + """ + + def dump_stream(self, iterator, stream): + for obj in iterator: + self._write_with_length(obj, stream) + + def load_stream(self, stream): + while True: + try: + yield self._read_with_length(stream) + except EOFError: + return + + def _write_with_length(self, obj, stream): + serialized = self._dumps(obj) + write_int(len(serialized), stream) + stream.write(serialized) + + def _read_with_length(self, stream): + length = read_int(stream) + obj = stream.read(length) + if obj == "": + raise EOFError + return self._loads(obj) + + def _dumps(self, obj): + """ + Serialize an object into a byte array. + When batching is used, this will be called with an array of objects. + """ + raise NotImplementedError + + def _loads(self, obj): + """ + Deserialize an object from a byte array. + """ + raise NotImplementedError + + +class BatchedSerializer(Serializer): + """ + Serializes a stream of objects in batches by calling its wrapped + Serializer with streams of objects. + """ + + UNLIMITED_BATCH_SIZE = -1 + + def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): + self.serializer = serializer + self.batchSize = batchSize + + def _batched(self, iterator): + if self.batchSize == self.UNLIMITED_BATCH_SIZE: + yield list(iterator) + else: + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == self.batchSize: + yield items + items = [] + count = 0 + if items: + yield items + + def dump_stream(self, iterator, stream): + if isinstance(iterator, basestring): + iterator = [iterator] + self.serializer.dump_stream(self._batched(iterator), stream) + + def load_stream(self, stream): + return chain.from_iterable(self._load_stream_without_unbatching(stream)) + + def _load_stream_without_unbatching(self, stream): + return self.serializer.load_stream(stream) + + def __eq__(self, other): + return isinstance(other, BatchedSerializer) and \ + other.serializer == self.serializer + + def __str__(self): + return "BatchedSerializer<%s>" % str(self.serializer) + + +class CartesianDeserializer(FramedSerializer): """ - Used to store multiple RDD entries as a single Java object. + Deserializes the JavaRDD cartesian() of two PythonRDDs. + """ + + def __init__(self, key_ser, val_ser): + self.key_ser = key_ser + self.val_ser = val_ser + + def load_stream(self, stream): + key_stream = self.key_ser._load_stream_without_unbatching(stream) + val_stream = self.val_ser._load_stream_without_unbatching(stream) + key_is_batched = isinstance(self.key_ser, BatchedSerializer) + val_is_batched = isinstance(self.val_ser, BatchedSerializer) + for (keys, vals) in izip(key_stream, val_stream): + keys = keys if key_is_batched else [keys] + vals = vals if val_is_batched else [vals] + for pair in product(keys, vals): + yield pair + + def __eq__(self, other): + return isinstance(other, CartesianDeserializer) and \ + self.key_ser == other.key_ser and self.val_ser == other.val_ser + + def __str__(self): + return "CartesianDeserializer<%s, %s>" % \ + (str(self.key_ser), str(self.val_ser)) + + +class NoOpSerializer(FramedSerializer): + + def _loads(self, obj): return obj + def _dumps(self, obj): return obj + + +class PickleSerializer(FramedSerializer): + """ + Serializes objects using Python's cPickle serializer: + + http://docs.python.org/2/library/pickle.html + + This serializer supports nearly any Python object, but may + not be as fast as more specialized serializers. + """ + + def _dumps(self, obj): return cPickle.dumps(obj, 2) + _loads = cPickle.loads + - This relieves us from having to explicitly track whether an RDD - is stored as batches of objects and avoids problems when processing - the union() of batched and unbatched RDDs (e.g. the union() of textFile() - with another RDD). +class MarshalSerializer(FramedSerializer): """ - def __init__(self, items): - self.items = items + Serializes objects using Python's Marshal serializer: + http://docs.python.org/2/library/marshal.html -def batched(iterator, batchSize): - if batchSize == -1: # unlimited batch size - yield Batch(list(iterator)) - else: - items = [] - count = 0 - for item in iterator: - items.append(item) - count += 1 - if count == batchSize: - yield Batch(items) - items = [] - count = 0 - if items: - yield Batch(items) + This serializer is faster than PickleSerializer but supports fewer datatypes. + """ + + _dumps = marshal.dumps + _loads = marshal.loads -def dump_pickle(obj): - return cPickle.dumps(obj, 2) +class MUTF8Deserializer(Serializer): + """ + Deserializes streams written by Java's DataOutputStream.writeUTF(). + """ + def _loads(self, stream): + length = struct.unpack('>H', stream.read(2))[0] + return stream.read(length).decode('utf8') -load_pickle = cPickle.loads + def load_stream(self, stream): + while True: + try: + yield self._loads(stream) + except struct.error: + return + except EOFError: + return def read_long(stream): @@ -90,43 +305,4 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) - stream.write(obj) - - -def read_mutf8(stream): - """ - Read a string written with Java's DataOutputStream.writeUTF() method. - """ - length = struct.unpack('>H', stream.read(2))[0] - return stream.read(length).decode('utf8') - - -def read_with_length(stream): - length = read_int(stream) - obj = stream.read(length) - if obj == "": - raise EOFError - return obj - - -def read_from_pickle_file(stream): - try: - while True: - obj = load_pickle(read_with_length(stream)) - if type(obj) == Batch: # We don't care about inheritance - for item in obj.items: - yield item - else: - yield obj - except EOFError: - return - - -def read_pairs_from_pickle_file(stream): - try: - while True: - a = load_pickle(read_with_length(stream)) - b = load_pickle(read_with_length(stream)) - yield (a, b) - except EOFError: - return
\ No newline at end of file + stream.write(obj)
\ No newline at end of file |