aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/serializers.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/serializers.py')
-rw-r--r--python/pyspark/serializers.py299
1 files changed, 251 insertions, 48 deletions
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 54fed1c9c7..2a500ab919 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -15,45 +15,269 @@
# limitations under the License.
#
-import struct
+"""
+PySpark supports custom serializers for transferring data; this can improve
+performance.
+
+By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
+C{cPickle} serializer, which can serialize nearly any Python object.
+Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
+faster.
+
+The serializer is chosen when creating L{SparkContext}:
+
+>>> from pyspark.context import SparkContext
+>>> from pyspark.serializers import MarshalSerializer
+>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
+>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+>>> sc.stop()
+
+By default, PySpark serializes objects in batches; the batch size can be
+controlled through SparkContext's C{batchSize} parameter
+(the default size is 1024 objects):
+
+>>> sc = SparkContext('local', 'test', batchSize=2)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+
+Behind the scenes, this creates a JavaRDD with four partitions, each of
+which contains two batches of two objects:
+
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+8L
+>>> sc.stop()
+
+A batch size of -1 uses an unlimited batch size, and a size of 1 disables
+batching:
+
+>>> sc = SparkContext('local', 'test', batchSize=1)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+16L
+"""
+
import cPickle
+from itertools import chain, izip, product
+import marshal
+import struct
+from pyspark import cloudpickle
+
+
+__all__ = ["PickleSerializer", "MarshalSerializer"]
+
+
+class SpecialLengths(object):
+ END_OF_DATA_SECTION = -1
+ PYTHON_EXCEPTION_THROWN = -2
+ TIMING_DATA = -3
+
+
+class Serializer(object):
+
+ def dump_stream(self, iterator, stream):
+ """
+ Serialize an iterator of objects to the output stream.
+ """
+ raise NotImplementedError
+
+ def load_stream(self, stream):
+ """
+ Return an iterator of deserialized objects from the input stream.
+ """
+ raise NotImplementedError
+
+
+ def _load_stream_without_unbatching(self, stream):
+ return self.load_stream(stream)
+
+ # Note: our notion of "equality" is that output generated by
+ # equal serializers can be deserialized using the same serializer.
+
+ # This default implementation handles the simple cases;
+ # subclasses should override __eq__ as appropriate.
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class FramedSerializer(Serializer):
+ """
+ Serializer that writes objects as a stream of (length, data) pairs,
+ where C{length} is a 32-bit integer and C{data} is C{length} bytes.
+ """
+
+ def dump_stream(self, iterator, stream):
+ for obj in iterator:
+ self._write_with_length(obj, stream)
+
+ def load_stream(self, stream):
+ while True:
+ try:
+ yield self._read_with_length(stream)
+ except EOFError:
+ return
+
+ def _write_with_length(self, obj, stream):
+ serialized = self.dumps(obj)
+ write_int(len(serialized), stream)
+ stream.write(serialized)
+
+ def _read_with_length(self, stream):
+ length = read_int(stream)
+ obj = stream.read(length)
+ if obj == "":
+ raise EOFError
+ return self.loads(obj)
+
+ def dumps(self, obj):
+ """
+ Serialize an object into a byte array.
+ When batching is used, this will be called with a list of objects.
+ """
+ raise NotImplementedError
+
+ def loads(self, obj):
+ """
+ Deserialize an object from a byte array.
+ """
+ raise NotImplementedError
+
+
+class BatchedSerializer(Serializer):
+ """
+ Serializes a stream of objects in batches by passing its wrapped
+ Serializer a stream of batches (lists of objects).
+ """
+
+ UNLIMITED_BATCH_SIZE = -1
+
+ def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
+ self.serializer = serializer
+ self.batchSize = batchSize
+
+ def _batched(self, iterator):
+ if self.batchSize == self.UNLIMITED_BATCH_SIZE:
+ yield list(iterator)
+ else:
+ items = []
+ count = 0
+ for item in iterator:
+ items.append(item)
+ count += 1
+ if count == self.batchSize:
+ yield items
+ items = []
+ count = 0
+ if items:
+ yield items
+
+ def dump_stream(self, iterator, stream):
+ self.serializer.dump_stream(self._batched(iterator), stream)
+
+ def load_stream(self, stream):
+ return chain.from_iterable(self._load_stream_without_unbatching(stream))
+
+ def _load_stream_without_unbatching(self, stream):
+ return self.serializer.load_stream(stream)
+
+ def __eq__(self, other):
+ return isinstance(other, BatchedSerializer) and \
+ other.serializer == self.serializer
+
+ def __str__(self):
+ return "BatchedSerializer<%s>" % str(self.serializer)
-class Batch(object):
+class CartesianDeserializer(FramedSerializer):
"""
- Used to store multiple RDD entries as a single Java object.
+ Deserializes the JavaRDD cartesian() of two PythonRDDs.
+ """
+
+ def __init__(self, key_ser, val_ser):
+ self.key_ser = key_ser
+ self.val_ser = val_ser
+
+ def load_stream(self, stream):
+ key_stream = self.key_ser._load_stream_without_unbatching(stream)
+ val_stream = self.val_ser._load_stream_without_unbatching(stream)
+ key_is_batched = isinstance(self.key_ser, BatchedSerializer)
+ val_is_batched = isinstance(self.val_ser, BatchedSerializer)
+ for (keys, vals) in izip(key_stream, val_stream):
+ keys = keys if key_is_batched else [keys]
+ vals = vals if val_is_batched else [vals]
+ for pair in product(keys, vals):
+ yield pair
+
+ def __eq__(self, other):
+ return isinstance(other, CartesianDeserializer) and \
+ self.key_ser == other.key_ser and self.val_ser == other.val_ser
+
+ def __str__(self):
+ return "CartesianDeserializer<%s, %s>" % \
+ (str(self.key_ser), str(self.val_ser))
- This relieves us from having to explicitly track whether an RDD
- is stored as batches of objects and avoids problems when processing
- the union() of batched and unbatched RDDs (e.g. the union() of textFile()
- with another RDD).
+
+class NoOpSerializer(FramedSerializer):
+
+ def loads(self, obj): return obj
+ def dumps(self, obj): return obj
+
+
+class PickleSerializer(FramedSerializer):
"""
- def __init__(self, items):
- self.items = items
+ Serializes objects using Python's cPickle serializer:
+ http://docs.python.org/2/library/pickle.html
-def batched(iterator, batchSize):
- if batchSize == -1: # unlimited batch size
- yield Batch(list(iterator))
- else:
- items = []
- count = 0
- for item in iterator:
- items.append(item)
- count += 1
- if count == batchSize:
- yield Batch(items)
- items = []
- count = 0
- if items:
- yield Batch(items)
+ This serializer supports nearly any Python object, but may
+ not be as fast as more specialized serializers.
+ """
+ def dumps(self, obj): return cPickle.dumps(obj, 2)
+ loads = cPickle.loads
-def dump_pickle(obj):
- return cPickle.dumps(obj, 2)
+class CloudPickleSerializer(PickleSerializer):
+ def dumps(self, obj): return cloudpickle.dumps(obj, 2)
-load_pickle = cPickle.loads
+
+class MarshalSerializer(FramedSerializer):
+ """
+ Serializes objects using Python's Marshal serializer:
+
+ http://docs.python.org/2/library/marshal.html
+
+ This serializer is faster than PickleSerializer but supports fewer datatypes.
+ """
+
+ dumps = marshal.dumps
+ loads = marshal.loads
+
+
+class MUTF8Deserializer(Serializer):
+ """
+ Deserializes streams written by Java's DataOutputStream.writeUTF().
+ """
+
+ def loads(self, stream):
+ length = struct.unpack('>H', stream.read(2))[0]
+ return stream.read(length).decode('utf8')
+
+ def load_stream(self, stream):
+ while True:
+ try:
+ yield self.loads(stream)
+ except struct.error:
+ return
+ except EOFError:
+ return
def read_long(stream):
@@ -85,24 +309,3 @@ def write_int(value, stream):
def write_with_length(obj, stream):
write_int(len(obj), stream)
stream.write(obj)
-
-
-def read_with_length(stream):
- length = read_int(stream)
- obj = stream.read(length)
- if obj == "":
- raise EOFError
- return obj
-
-
-def read_from_pickle_file(stream):
- try:
- while True:
- obj = load_pickle(read_with_length(stream))
- if type(obj) == Batch: # We don't care about inheritance
- for item in obj.items:
- yield item
- else:
- yield obj
- except EOFError:
- return