Diffstat (limited to 'python/pyspark/serializers.py')
-rw-r--r--  python/pyspark/serializers.py  58
1 file changed, 35 insertions(+), 23 deletions(-)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 2a1326947f..c4f2f08cb4 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -61,7 +61,7 @@ import itertools
 if sys.version < '3':
     import cPickle as pickle
     protocol = 2
-    from itertools import izip as zip
+    from itertools import izip as zip, imap as map
 else:
     import pickle
     protocol = 3
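
A brief aside on the imap alias above (an illustrative sketch, not part of the patch): Python 2's built-in map() is eager and would realize a whole deserialization stream at once, whereas itertools.imap, like Python 3's map(), yields elements on demand. That laziness is what the new default _load_stream_without_unbatching below depends on. Runnable as-is on Python 3:

    def records():
        # stands in for a stream of deserialized objects
        for i in range(3):
            print("deserializing %d" % i)
            yield i

    batches = map(lambda x: [x], records())  # lazy: nothing printed yet
    print(next(batches))                     # prints "deserializing 0", then [0]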
@@ -96,7 +96,12 @@ class Serializer(object):
         raise NotImplementedError

     def _load_stream_without_unbatching(self, stream):
-        return self.load_stream(stream)
+        """
+        Return an iterator of deserialized batches (lists) of objects from the input stream.
+        If the serializer does not operate on batches, the default implementation returns an
+        iterator of single-element lists.
+        """
+        return map(lambda x: [x], self.load_stream(stream))

     # Note: our notion of "equality" is that output generated by
     # equal serializers can be deserialized using the same serializer.
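
The new default above gives every serializer a uniform stream-of-batches view. A minimal standalone sketch of that contract (the variable names are illustrative, not from the patch):

    from itertools import chain

    plain = iter([1, 2, 3])              # stands in for load_stream() output
    batches = map(lambda x: [x], plain)  # the new default: one-element batches

    # load_stream() can always be recovered by flattening the batch stream,
    # which is exactly what the reworked deserializers below do
    assert list(chain.from_iterable(batches)) == [1, 2, 3]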
@@ -278,50 +283,57 @@ class AutoBatchedSerializer(BatchedSerializer):
         return "AutoBatchedSerializer(%s)" % self.serializer


-class CartesianDeserializer(FramedSerializer):
+class CartesianDeserializer(Serializer):

     """
     Deserializes the JavaRDD cartesian() of two PythonRDDs.
+    Due to PySpark batching we cannot simply use the result of the Java RDD cartesian;
+    we additionally need to do the cartesian within each pair of batches.
     """

     def __init__(self, key_ser, val_ser):
-        FramedSerializer.__init__(self)
         self.key_ser = key_ser
         self.val_ser = val_ser

-    def prepare_keys_values(self, stream):
-        key_stream = self.key_ser._load_stream_without_unbatching(stream)
-        val_stream = self.val_ser._load_stream_without_unbatching(stream)
-        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
-        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
-        for (keys, vals) in zip(key_stream, val_stream):
-            keys = keys if key_is_batched else [keys]
-            vals = vals if val_is_batched else [vals]
-            yield (keys, vals)
+    def _load_stream_without_unbatching(self, stream):
+        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
+        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
+            # for correctness with repeated cartesian/zip this must be returned as one batch
+            yield product(key_batch, val_batch)

     def load_stream(self, stream):
-        for (keys, vals) in self.prepare_keys_values(stream):
-            for pair in product(keys, vals):
-                yield pair
+        return chain.from_iterable(self._load_stream_without_unbatching(stream))

     def __repr__(self):
         return "CartesianDeserializer(%s, %s)" % \
                (str(self.key_ser), str(self.val_ser))
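
To make the batch-wise cartesian concrete, a hedged standalone sketch with toy batches (key_batches and val_batches are illustrative stand-ins for the serializers' batch streams):

    from itertools import chain, product

    key_batches = iter([[1, 2], [3]])        # as if from key_ser's batch stream
    val_batches = iter([['a', 'b'], ['c']])  # as if from val_ser's batch stream

    # each aligned pair of batches expands into one cartesian-product batch,
    # so batch boundaries survive any later cartesian/zip
    out = (product(kb, vb) for (kb, vb) in zip(key_batches, val_batches))
    assert list(chain.from_iterable(out)) == \
        [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b'), (3, 'c')]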
-class PairDeserializer(CartesianDeserializer):
+class PairDeserializer(Serializer):

     """
     Deserializes the JavaRDD zip() of two PythonRDDs.
+    Due to PySpark batching we cannot simply use the result of the Java RDD zip;
+    we additionally need to do the zip within each pair of batches.
     """

+    def __init__(self, key_ser, val_ser):
+        self.key_ser = key_ser
+        self.val_ser = val_ser
+
+    def _load_stream_without_unbatching(self, stream):
+        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
+        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
+            if len(key_batch) != len(val_batch):
+                raise ValueError("Can not deserialize PairRDD with different number of items"
+                                 " in batches: (%d, %d)" % (len(key_batch), len(val_batch)))
+            # for correctness with repeated cartesian/zip this must be returned as one batch
+            yield zip(key_batch, val_batch)
+
     def load_stream(self, stream):
-        for (keys, vals) in self.prepare_keys_values(stream):
-            if len(keys) != len(vals):
-                raise ValueError("Can not deserialize RDD with different number of items"
-                                 " in pair: (%d, %d)" % (len(keys), len(vals)))
-            for pair in zip(keys, vals):
-                yield pair
+        return chain.from_iterable(self._load_stream_without_unbatching(stream))

     def __repr__(self):
         return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser))