about summary refs log tree commit diff
path: root/python/pyspark/serializers.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/serializers.py')
-rw-r--r-- python/pyspark/serializers.py 29
1 files changed, 28 insertions, 1 deletions
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 9be78b39fb..03b31ae962 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -193,7 +193,7 @@ class BatchedSerializer(Serializer):
return chain.from_iterable(self._load_stream_without_unbatching(stream))
def _load_stream_without_unbatching(self, stream):
- return self.serializer.load_stream(stream)
+ return self.serializer.load_stream(stream)
def __eq__(self, other):
return (isinstance(other, BatchedSerializer) and
@@ -302,6 +302,33 @@ class MarshalSerializer(FramedSerializer):
loads = marshal.loads
+class AutoSerializer(FramedSerializer):
+    """
+    Choose marshal or cPickle as the serialization protocol automatically.
+
+    Every frame produced by dumps() starts with a one-character tag that
+    names the protocol used for the rest of the frame: 'M' for marshal,
+    'P' for cPickle. loads() dispatches on that tag.
+    """
+
+    def __init__(self):
+        FramedSerializer.__init__(self)
+        # None until marshal fails once; after that it is 'P', which makes
+        # dumps() skip the marshal attempt for every later object.
+        self._type = None
+
+    def dumps(self, obj):
+        """
+        Serialize obj, trying marshal first and falling back to cPickle.
+
+        Returns the tag character followed by the serialized payload.
+        """
+        if self._type is not None:
+            return 'P' + cPickle.dumps(obj, -1)
+        try:
+            return 'M' + marshal.dumps(obj)
+        except Exception:
+            # marshal could not handle this object; remember that so later
+            # calls on this instance go straight to cPickle.
+            self._type = 'P'
+            return 'P' + cPickle.dumps(obj, -1)
+
+    def loads(self, obj):
+        """
+        Deserialize a frame produced by dumps().
+
+        Raises ValueError if the frame is empty or its tag is neither
+        'M' nor 'P'.
+        """
+        # Slice rather than index so that an empty frame reaches the
+        # ValueError below instead of raising IndexError.
+        _type = obj[:1]
+        if _type == 'M':
+            return marshal.loads(obj[1:])
+        elif _type == 'P':
+            return cPickle.loads(obj[1:])
+        else:
+            raise ValueError("invalid serialization type: %s" % _type)
+
+
class UTF8Deserializer(Serializer):
"""
Deserializes streams written by String.getBytes.