Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/rdd.py          | 25 +++++++++++++++++++++++++
-rw-r--r--  python/pyspark/serializers.py  |  3 +++
-rw-r--r--  python/pyspark/tests.py        | 27 ++++++++++++++++++++++++++-
3 files changed, 54 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 86cd89b245..140cbe05a4 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1687,6 +1687,31 @@ class RDD(object):
         >>> x.zip(y).collect()
         [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
         """
+        if self.getNumPartitions() != other.getNumPartitions():
+            raise ValueError("Can only zip with RDD which has the same number of partitions")
+
+        def get_batch_size(ser):
+            if isinstance(ser, BatchedSerializer):
+                return ser.batchSize
+            return 0
+
+        def batch_as(rdd, batchSize):
+            ser = rdd._jrdd_deserializer
+            if isinstance(ser, BatchedSerializer):
+                ser = ser.serializer
+            return rdd._reserialize(BatchedSerializer(ser, batchSize))
+
+        my_batch = get_batch_size(self._jrdd_deserializer)
+        other_batch = get_batch_size(other._jrdd_deserializer)
+        if my_batch != other_batch:
+            # Re-batch the side with the smaller batch size, using the larger one.
+            if my_batch > other_batch:
+                other = batch_as(other, my_batch)
+            else:
+                self = batch_as(self, other_batch)
+
+        # The JVM side will throw an exception if corresponding partitions
+        # contain different numbers of items.
         pairRDD = self._jrdd.zip(other._jrdd)
         deserializer = PairDeserializer(self._jrdd_deserializer,
                                         other._jrdd_deserializer)
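
A quick illustration of what the re-batching above buys (a minimal sketch mirroring the new test below; it assumes a live SparkContext bound to sc, e.g. in a pyspark shell):

    from pyspark.serializers import BatchedSerializer, MarshalSerializer, PickleSerializer

    a = sc.parallelize(range(5))
    b = sc.parallelize(range(100, 105))

    # Put the two sides on different serializers and batch sizes.
    a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))  # batches of 2
    b = b._reserialize(MarshalSerializer())                       # unbatched

    # zip() now re-batches one side so both partition streams align
    # element for element instead of mis-pairing whole batches.
    print(a.zip(b).collect())
    # [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]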
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 74870c0edc..fc49aa42db 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -255,6 +255,9 @@ class PairDeserializer(CartesianDeserializer):
     def load_stream(self, stream):
         for (keys, vals) in self.prepare_keys_values(stream):
+            if len(keys) != len(vals):
+                raise ValueError("Can not deserialize RDD with different number of items"
+                                 " in pair: (%d, %d)" % (len(keys), len(vals)))
             for pair in izip(keys, vals):
                 yield pair
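
The length check matters because izip silently truncates to the shorter iterable; under Python 2 (which PySpark targeted at the time):

    from itertools import izip

    keys, vals = [0, 1, 2], [100, 101]   # mismatched batch lengths
    list(izip(keys, vals))               # [(0, 100), (1, 101)]
    # The extra key is dropped with no error. The new guard turns this
    # silent data loss into an explicit ValueError.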
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 69d543d9d0..51bfbb47e5 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -39,7 +39,7 @@ else:
 from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
-from pyspark.serializers import read_int
+from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
 _have_scipy = False
@@ -339,6 +339,31 @@ class TestRDDFunctions(PySparkTestCase):
         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
         self.assertEquals(N, m)
+    def test_zip_with_different_serializers(self):
+        a = self.sc.parallelize(range(5))
+        b = self.sc.parallelize(range(100, 105))
+        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
+        a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
+        b = b._reserialize(MarshalSerializer())
+        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
+
+    def test_zip_with_different_number_of_items(self):
+        a = self.sc.parallelize(range(5), 2)
+        # different number of partitions
+        b = self.sc.parallelize(range(100, 106), 3)
+        self.assertRaises(ValueError, lambda: a.zip(b))
+        # different number of batched items in JVM
+        b = self.sc.parallelize(range(100, 104), 2)
+        self.assertRaises(Exception, lambda: a.zip(b).count())
+        # different number of items in one pair
+        b = self.sc.parallelize(range(100, 106), 2)
+        self.assertRaises(Exception, lambda: a.zip(b).count())
+        # same total number of items, but different distributions
+        a = self.sc.parallelize([2, 3], 2).flatMap(range)
+        b = self.sc.parallelize([3, 2], 2).flatMap(range)
+        self.assertEquals(a.count(), b.count())
+        self.assertRaises(Exception, lambda: a.zip(b).count())
+
 class TestIO(PySparkTestCase):
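
The last assertion in test_zip_with_different_number_of_items is the subtle one: zip() pairs partitions positionally, so the per-partition counts must match even when the totals agree. A sketch of why (same data as the test; again assumes a live SparkContext sc):

    a = sc.parallelize([2, 3], 2).flatMap(range)  # partitions: [0, 1] and [0, 1, 2]
    b = sc.parallelize([3, 2], 2).flatMap(range)  # partitions: [0, 1, 2] and [0, 1]

    a.count() == b.count()   # True: 5 items on each side in total
    a.zip(b).count()         # raises: partition 0 pairs 2 items against 3,
                             # which the new checks surface as an error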