diff options
Diffstat (limited to 'python/pyspark/serializers.py')
-rw-r--r-- | python/pyspark/serializers.py | 29 |
1 files changed, 28 insertions, 1 deletions
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 9be78b39fb..03b31ae962 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -193,7 +193,7 @@ class BatchedSerializer(Serializer): return chain.from_iterable(self._load_stream_without_unbatching(stream)) def _load_stream_without_unbatching(self, stream): - return self.serializer.load_stream(stream) + return self.serializer.load_stream(stream) def __eq__(self, other): return (isinstance(other, BatchedSerializer) and @@ -302,6 +302,33 @@ class MarshalSerializer(FramedSerializer): loads = marshal.loads +class AutoSerializer(FramedSerializer): + """ + Choose marshal or cPickle as serialization protocol autumatically + """ + def __init__(self): + FramedSerializer.__init__(self) + self._type = None + + def dumps(self, obj): + if self._type is not None: + return 'P' + cPickle.dumps(obj, -1) + try: + return 'M' + marshal.dumps(obj) + except Exception: + self._type = 'P' + return 'P' + cPickle.dumps(obj, -1) + + def loads(self, obj): + _type = obj[0] + if _type == 'M': + return marshal.loads(obj[1:]) + elif _type == 'P': + return cPickle.loads(obj[1:]) + else: + raise ValueError("invalid sevialization type: %s" % _type) + + class UTF8Deserializer(Serializer): """ Deserializes streams written by String.getBytes. |