From 3540d4b387568a4017fcd772233e4e10c1beb1b4 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 19 Aug 2014 14:46:32 -0700
Subject: [SPARK-2790] [PySpark] fix zip with serializers which have different
 batch sizes.

If two RDDs have different batch sizes in their serializers, re-serialize
the one with the smaller batch size to match the larger one, then call
RDD.zip() in Spark.

Author: Davies Liu

Closes #1894 from davies/zip and squashes the following commits:

c4652ea [Davies Liu] add more test cases
6d05fc8 [Davies Liu] Merge branch 'master' into zip
813b1e4 [Davies Liu] add more tests for failed cases
a4aafda [Davies Liu] fix zip with serializers which have different batch sizes.

(cherry picked from commit d7e80c2597d4a9cae2e0cb35a86f7889323f4cbb)
Signed-off-by: Josh Rosen
---
 python/pyspark/rdd.py         | 25 +++++++++++++++++++++++++
 python/pyspark/serializers.py |  3 +++
 python/pyspark/tests.py       | 27 ++++++++++++++++++++++++++-
 3 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 86cd89b245..140cbe05a4 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1687,6 +1687,31 @@ class RDD(object):
         >>> x.zip(y).collect()
         [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
         """
+        if self.getNumPartitions() != other.getNumPartitions():
+            raise ValueError("Can only zip with RDD which has the same number of partitions")
+
+        def get_batch_size(ser):
+            if isinstance(ser, BatchedSerializer):
+                return ser.batchSize
+            return 0
+
+        def batch_as(rdd, batchSize):
+            ser = rdd._jrdd_deserializer
+            if isinstance(ser, BatchedSerializer):
+                ser = ser.serializer
+            return rdd._reserialize(BatchedSerializer(ser, batchSize))
+
+        my_batch = get_batch_size(self._jrdd_deserializer)
+        other_batch = get_batch_size(other._jrdd_deserializer)
+        if my_batch != other_batch:
+            # use the greater batchSize to re-batch the other one
+            if my_batch > other_batch:
+                other = batch_as(other, my_batch)
+            else:
+                self = batch_as(self, other_batch)
+
+        # The JVM will raise an exception if the partitions have different
+        # numbers of items.
         pairRDD = self._jrdd.zip(other._jrdd)
         deserializer = PairDeserializer(self._jrdd_deserializer,
                                         other._jrdd_deserializer)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 74870c0edc..fc49aa42db 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -255,6 +255,9 @@ class PairDeserializer(CartesianDeserializer):
 
     def load_stream(self, stream):
         for (keys, vals) in self.prepare_keys_values(stream):
+            if len(keys) != len(vals):
+                raise ValueError("Can not deserialize RDD with different number of items"
+                                 " in pair: (%d, %d)" % (len(keys), len(vals)))
             for pair in izip(keys, vals):
                 yield pair
 
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 69d543d9d0..51bfbb47e5 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -39,7 +39,7 @@ else:
 
 from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
-from pyspark.serializers import read_int
+from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
 
 _have_scipy = False
@@ -339,6 +339,31 @@ class TestRDDFunctions(PySparkTestCase):
         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
         self.assertEquals(N, m)
 
+    def test_zip_with_different_serializers(self):
+        a = self.sc.parallelize(range(5))
+        b = self.sc.parallelize(range(100, 105))
+        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
+        a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
+        b = b._reserialize(MarshalSerializer())
+        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
+
+    def test_zip_with_different_number_of_items(self):
+        a = self.sc.parallelize(range(5), 2)
+        # different number of partitions
+        b = self.sc.parallelize(range(100, 106), 3)
+        self.assertRaises(ValueError, lambda: a.zip(b))
+        # different number of batched items in JVM
+        b = self.sc.parallelize(range(100, 104), 2)
+        self.assertRaises(Exception, lambda: a.zip(b).count())
+        # different number of items in one pair
+        b = self.sc.parallelize(range(100, 106), 2)
+        self.assertRaises(Exception, lambda: a.zip(b).count())
+        # same total number of items, but different distributions
+        a = self.sc.parallelize([2, 3], 2).flatMap(range)
+        b = self.sc.parallelize([3, 2], 2).flatMap(range)
+        self.assertEquals(a.count(), b.count())
+        self.assertRaises(Exception, lambda: a.zip(b).count())
+
 
 class TestIO(PySparkTestCase):
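For readers trying the patch out, here is a minimal sketch of the behavior it
enables, adapted from the new test_zip_with_different_serializers test above.
It assumes a local SparkContext named sc, and it calls the internal
_reserialize helper (shown in the diff) purely to force mismatched batch
sizes for illustration; it is not part of the public API:

    from pyspark import SparkContext
    from pyspark.serializers import BatchedSerializer, MarshalSerializer, PickleSerializer

    sc = SparkContext("local", "zip-batch-sizes")

    # Two RDDs whose deserializers use different batch sizes
    # (batches of 2 vs. an unbatched serializer).
    a = sc.parallelize(range(5))._reserialize(BatchedSerializer(PickleSerializer(), 2))
    b = sc.parallelize(range(100, 105))._reserialize(MarshalSerializer())

    # With this patch, zip() re-batches the RDD with the smaller batch size
    # to match the larger one before handing both to the JVM-side zip,
    # so the pairs line up correctly.
    print(a.zip(b).collect())  # [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]

    sc.stop()

Without the re-batching, the JVM would zip mis-aligned batches; the new check
in PairDeserializer.load_stream turns any remaining mismatch into an explicit
ValueError instead of silently producing wrong pairs.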