aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/rdd.py25
-rw-r--r--python/pyspark/serializers.py6
-rw-r--r--python/pyspark/tests.py9
3 files changed, 26 insertions, 14 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 57754776fa..bd2ff00c0f 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -469,8 +469,7 @@ class RDD(object):
def _reserialize(self, serializer=None):
serializer = serializer or self.ctx.serializer
if self._jrdd_deserializer != serializer:
- if not isinstance(self, PipelinedRDD):
- self = self.map(lambda x: x, preservesPartitioning=True)
+ self = self.map(lambda x: x, preservesPartitioning=True)
self._jrdd_deserializer = serializer
return self
@@ -1798,23 +1797,21 @@ class RDD(object):
def get_batch_size(ser):
if isinstance(ser, BatchedSerializer):
return ser.batchSize
- return 1
+ return 1 # not batched
def batch_as(rdd, batchSize):
- ser = rdd._jrdd_deserializer
- if isinstance(ser, BatchedSerializer):
- ser = ser.serializer
- return rdd._reserialize(BatchedSerializer(ser, batchSize))
+ return rdd._reserialize(BatchedSerializer(PickleSerializer(), batchSize))
my_batch = get_batch_size(self._jrdd_deserializer)
other_batch = get_batch_size(other._jrdd_deserializer)
- # use the smallest batchSize for both of them
- batchSize = min(my_batch, other_batch)
- if batchSize <= 0:
- # auto batched or unlimited
- batchSize = 100
- other = batch_as(other, batchSize)
- self = batch_as(self, batchSize)
+ if my_batch != other_batch:
+ # use the smallest batchSize for both of them
+ batchSize = min(my_batch, other_batch)
+ if batchSize <= 0:
+ # auto batched or unlimited
+ batchSize = 100
+ other = batch_as(other, batchSize)
+ self = batch_as(self, batchSize)
if self.getNumPartitions() != other.getNumPartitions():
raise ValueError("Can only zip with RDD which has the same number of partitions")
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 33aa55f7f1..bd08c9a6d2 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -463,6 +463,9 @@ class CompressedSerializer(FramedSerializer):
def loads(self, obj):
return self.serializer.loads(zlib.decompress(obj))
+ def __eq__(self, other):
+ return isinstance(other, CompressedSerializer) and self.serializer == other.serializer
+
class UTF8Deserializer(Serializer):
@@ -489,6 +492,9 @@ class UTF8Deserializer(Serializer):
except EOFError:
return
+ def __eq__(self, other):
+ return isinstance(other, UTF8Deserializer) and self.use_unicode == other.use_unicode
+
def read_long(stream):
length = stream.read(8)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 32645778c2..bca52a7ce6 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -533,6 +533,15 @@ class RDDTests(ReusedPySparkTestCase):
a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
b = b._reserialize(MarshalSerializer())
self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
+ # regression test for SPARK-4841
+ path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ t = self.sc.textFile(path)
+ cnt = t.count()
+ self.assertEqual(cnt, t.zip(t).count())
+ rdd = t.map(str)
+ self.assertEqual(cnt, t.zip(rdd).count())
+ # regression test for bug in _reserialize()
+ self.assertEqual(cnt, t.zip(rdd).count())
def test_zip_with_different_number_of_items(self):
a = self.sc.parallelize(range(5), 2)