aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2014-12-15 22:58:26 -0800
committerJosh Rosen <joshrosen@databricks.com>2014-12-17 12:22:28 -0800
commit0429ec3089afc03064f8ad4608b951ef324f34d8 (patch)
treefd259257a6b6482fe3d9a6572e4513cb97fecd02 /python
parent76c88c6687033bd3cb9a686ea922098f4c0212ad (diff)
downloadspark-0429ec3089afc03064f8ad4608b951ef324f34d8.tar.gz
spark-0429ec3089afc03064f8ad4608b951ef324f34d8.tar.bz2
spark-0429ec3089afc03064f8ad4608b951ef324f34d8.zip
[SPARK-4841] fix zip with textFile()
UTF8Deserializer can not be used in BatchedSerializer, so always use PickleSerializer() when change batchSize in zip(). Also, if two RDD have the same batch size already, they did not need re-serialize any more. Author: Davies Liu <davies@databricks.com> Closes #3706 from davies/fix_4841 and squashes the following commits: 20ce3a3 [Davies Liu] fix bug in _reserialize() e3ebf7c [Davies Liu] add comment 379d2c8 [Davies Liu] fix zip with textFile() (cherry picked from commit c246b95dd2f565043db429c38c6cc029a0b870c1) Signed-off-by: Josh Rosen <joshrosen@databricks.com>
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/rdd.py25
-rw-r--r--python/pyspark/serializers.py6
-rw-r--r--python/pyspark/tests.py9
3 files changed, 26 insertions, 14 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 57754776fa..bd2ff00c0f 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -469,8 +469,7 @@ class RDD(object):
def _reserialize(self, serializer=None):
serializer = serializer or self.ctx.serializer
if self._jrdd_deserializer != serializer:
- if not isinstance(self, PipelinedRDD):
- self = self.map(lambda x: x, preservesPartitioning=True)
+ self = self.map(lambda x: x, preservesPartitioning=True)
self._jrdd_deserializer = serializer
return self
@@ -1798,23 +1797,21 @@ class RDD(object):
def get_batch_size(ser):
if isinstance(ser, BatchedSerializer):
return ser.batchSize
- return 1
+ return 1 # not batched
def batch_as(rdd, batchSize):
- ser = rdd._jrdd_deserializer
- if isinstance(ser, BatchedSerializer):
- ser = ser.serializer
- return rdd._reserialize(BatchedSerializer(ser, batchSize))
+ return rdd._reserialize(BatchedSerializer(PickleSerializer(), batchSize))
my_batch = get_batch_size(self._jrdd_deserializer)
other_batch = get_batch_size(other._jrdd_deserializer)
- # use the smallest batchSize for both of them
- batchSize = min(my_batch, other_batch)
- if batchSize <= 0:
- # auto batched or unlimited
- batchSize = 100
- other = batch_as(other, batchSize)
- self = batch_as(self, batchSize)
+ if my_batch != other_batch:
+ # use the smallest batchSize for both of them
+ batchSize = min(my_batch, other_batch)
+ if batchSize <= 0:
+ # auto batched or unlimited
+ batchSize = 100
+ other = batch_as(other, batchSize)
+ self = batch_as(self, batchSize)
if self.getNumPartitions() != other.getNumPartitions():
raise ValueError("Can only zip with RDD which has the same number of partitions")
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 33aa55f7f1..bd08c9a6d2 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -463,6 +463,9 @@ class CompressedSerializer(FramedSerializer):
def loads(self, obj):
return self.serializer.loads(zlib.decompress(obj))
+ def __eq__(self, other):
+ return isinstance(other, CompressedSerializer) and self.serializer == other.serializer
+
class UTF8Deserializer(Serializer):
@@ -489,6 +492,9 @@ class UTF8Deserializer(Serializer):
except EOFError:
return
+ def __eq__(self, other):
+ return isinstance(other, UTF8Deserializer) and self.use_unicode == other.use_unicode
+
def read_long(stream):
length = stream.read(8)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 32645778c2..bca52a7ce6 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -533,6 +533,15 @@ class RDDTests(ReusedPySparkTestCase):
a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
b = b._reserialize(MarshalSerializer())
self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
+ # regression test for SPARK-4841
+ path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ t = self.sc.textFile(path)
+ cnt = t.count()
+ self.assertEqual(cnt, t.zip(t).count())
+ rdd = t.map(str)
+ self.assertEqual(cnt, t.zip(rdd).count())
+ # regression test for bug in _reserializer()
+ self.assertEqual(cnt, t.zip(rdd).count())
def test_zip_with_different_number_of_items(self):
a = self.sc.parallelize(range(5), 2)