author    Andrew Ray <ray.andrew@gmail.com>    2016-12-08 11:08:12 -0800
committer Davies Liu <davies.liu@gmail.com>    2016-12-08 11:08:12 -0800
commit    3c68944b229aaaeeaee3efcbae3e3be9a2914855 (patch)
tree      8f6cf65d6396567a42c7d442d37fc1a1f29438b5
parent    ed8869ebbf39783b16daba2e2498a2bc1889306f (diff)
[SPARK-16589] [PYTHON] Chained cartesian produces incorrect number of records
## What changes were proposed in this pull request?

Fixes a bug in the Python implementation of the RDD cartesian product, related to batching, that showed up in repeated cartesian products with seemingly random results. The root cause was multiple iterators pulling from the same stream in the wrong order because of logic that ignored batching.

`CartesianDeserializer` and `PairDeserializer` were changed to implement `_load_stream_without_unbatching` and borrow the one-line implementation of `load_stream` from `BatchedSerializer`. The default implementation of `_load_stream_without_unbatching` was changed to give consistent results (always an iterable) so that it can be used without additional checks.

`PairDeserializer` no longer extends `CartesianDeserializer`, as that relationship was never really proper; if desired, a new common superclass could be added. Both `CartesianDeserializer` and `PairDeserializer` now extend only `Serializer` (which has no `dump_stream` implementation), since they are meant for *de*serialization only.

## How was this patch tested?

Additional unit tests (sourced from #14248) plus one testing a cartesian with zip.

Author: Andrew Ray <ray.andrew@gmail.com>

Closes #16121 from aray/fix-cartesian.
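At the user level, the fix guarantees the count and contents of chained cartesian products. A minimal sketch mirroring the new regression test (it assumes a live `SparkContext` bound to `sc`, e.g. from a pyspark shell):

```python
# Assumes an existing SparkContext `sc`.
rdd = sc.parallelize(range(10), 2)

# Before this fix, repeated cartesian products could return a seemingly
# random (and wrong) number of records; now the full cross product appears.
result = rdd.cartesian(rdd).cartesian(rdd).collect()
assert len(result) == 10 ** 3
assert set(result) == set(((x, y), z)
                          for x in range(10)
                          for y in range(10)
                          for z in range(10))
```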
-rw-r--r--  python/pyspark/serializers.py | 58
-rw-r--r--  python/pyspark/tests.py       | 18
2 files changed, 53 insertions(+), 23 deletions(-)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 2a1326947f..c4f2f08cb4 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -61,7 +61,7 @@ import itertools
if sys.version < '3':
import cPickle as pickle
protocol = 2
- from itertools import izip as zip
+ from itertools import izip as zip, imap as map
else:
import pickle
protocol = 3
@@ -96,7 +96,12 @@ class Serializer(object):
raise NotImplementedError
def _load_stream_without_unbatching(self, stream):
- return self.load_stream(stream)
+ """
+ Return an iterator of deserialized batches (lists) of objects from the input stream.
+ If the serializer does not operate on batches, the default implementation returns an
+ iterator of single-element lists.
+ """
+ return map(lambda x: [x], self.load_stream(stream))
# Note: our notion of "equality" is that output generated by
# equal serializers can be deserialized using the same serializer.
@@ -278,50 +283,57 @@ class AutoBatchedSerializer(BatchedSerializer):
return "AutoBatchedSerializer(%s)" % self.serializer
-class CartesianDeserializer(FramedSerializer):
+class CartesianDeserializer(Serializer):
"""
Deserializes the JavaRDD cartesian() of two PythonRDDs.
+ Due to PySpark batching we cannot simply use the result of the Java RDD cartesian;
+ we additionally need to do the cartesian within each pair of batches.
"""
def __init__(self, key_ser, val_ser):
- FramedSerializer.__init__(self)
self.key_ser = key_ser
self.val_ser = val_ser
- def prepare_keys_values(self, stream):
- key_stream = self.key_ser._load_stream_without_unbatching(stream)
- val_stream = self.val_ser._load_stream_without_unbatching(stream)
- key_is_batched = isinstance(self.key_ser, BatchedSerializer)
- val_is_batched = isinstance(self.val_ser, BatchedSerializer)
- for (keys, vals) in zip(key_stream, val_stream):
- keys = keys if key_is_batched else [keys]
- vals = vals if val_is_batched else [vals]
- yield (keys, vals)
+ def _load_stream_without_unbatching(self, stream):
+ key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
+ val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
+ for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
+ # for correctness with repeated cartesian/zip this must be returned as one batch
+ yield product(key_batch, val_batch)
def load_stream(self, stream):
- for (keys, vals) in self.prepare_keys_values(stream):
- for pair in product(keys, vals):
- yield pair
+ return chain.from_iterable(self._load_stream_without_unbatching(stream))
def __repr__(self):
return "CartesianDeserializer(%s, %s)" % \
(str(self.key_ser), str(self.val_ser))
-class PairDeserializer(CartesianDeserializer):
+class PairDeserializer(Serializer):
"""
Deserializes the JavaRDD zip() of two PythonRDDs.
+ Due to PySpark batching we cannot simply use the result of the Java RDD zip;
+ we additionally need to do the zip within each pair of batches.
"""
+ def __init__(self, key_ser, val_ser):
+ self.key_ser = key_ser
+ self.val_ser = val_ser
+
+ def _load_stream_without_unbatching(self, stream):
+ key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
+ val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
+ for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
+ if len(key_batch) != len(val_batch):
+ raise ValueError("Can not deserialize PairRDD with different number of items"
+ " in batches: (%d, %d)" % (len(key_batch), len(val_batch)))
+ # for correctness with repeated cartesian/zip this must be returned as one batch
+ yield zip(key_batch, val_batch)
+
def load_stream(self, stream):
- for (keys, vals) in self.prepare_keys_values(stream):
- if len(keys) != len(vals):
- raise ValueError("Can not deserialize RDD with different number of items"
- " in pair: (%d, %d)" % (len(keys), len(vals)))
- for pair in zip(keys, vals):
- yield pair
+ return chain.from_iterable(self._load_stream_without_unbatching(stream))
def __repr__(self):
return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser))
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index ab4bef8329..89fce8ab25 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -548,6 +548,24 @@ class RDDTests(ReusedPySparkTestCase):
self.assertEqual(u"Hello World!", x.strip())
self.assertEqual(u"Hello World!", y.strip())
+ def test_cartesian_chaining(self):
+ # Tests for SPARK-16589
+ rdd = self.sc.parallelize(range(10), 2)
+ self.assertSetEqual(
+ set(rdd.cartesian(rdd).cartesian(rdd).collect()),
+ set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)])
+ )
+
+ self.assertSetEqual(
+ set(rdd.cartesian(rdd.cartesian(rdd)).collect()),
+ set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)])
+ )
+
+ self.assertSetEqual(
+ set(rdd.cartesian(rdd.zip(rdd)).collect()),
+ set([(x, (y, y)) for x in range(10) for y in range(10)])
+ )
+
def test_deleting_input_files(self):
# Regression test for SPARK-1025
tempFile = tempfile.NamedTemporaryFile(delete=False)