Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/rdd.py          20
-rw-r--r--  python/pyspark/serializers.py  29
2 files changed, 47 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index e72f57d9d1..5ab27ff402 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -30,7 +30,7 @@ from threading import Thread
 import warnings
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
-    BatchedSerializer, CloudPickleSerializer, pack_long
+    BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -1081,6 +1081,24 @@ class RDD(object):
         jrdd = self._jrdd.coalesce(numPartitions)
         return RDD(jrdd, self.ctx, self._jrdd_deserializer)
+    def zip(self, other):
+        """
+        Zips this RDD with another one, returning key-value pairs made of the
+        first element in each RDD, the second element in each RDD, and so on.
+        Assumes that the two RDDs have the same number of partitions and the
+        same number of elements in each partition (e.g. one was made through
+        a map on the other).
+
+        >>> x = sc.parallelize(range(0, 5))
+        >>> y = sc.parallelize(range(1000, 1005))
+        >>> x.zip(y).collect()
+        [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
+        """
+        pairRDD = self._jrdd.zip(other._jrdd)
+        deserializer = PairDeserializer(self._jrdd_deserializer,
+                                        other._jrdd_deserializer)
+        return RDD(pairRDD, self.ctx, deserializer)
+
+
     # TODO: `lookup` is disabled because we can't make direct comparisons based
     # on the key; we need to compare the hash of the key to the hash of the
     # keys in the pairs. This could be an expensive operation, since those
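
For a sense of how the new zip() is used, here is a minimal sketch (plain PySpark; it assumes a live SparkContext named sc, as the doctest above does, and the variable names are illustrative only). The partition-alignment requirement is satisfied here because the second RDD is produced by a map() over the first:

    # Sketch: zip an RDD with a map of itself, so both sides have the same
    # number of partitions and the same number of elements per partition.
    nums = sc.parallelize(range(0, 5), 2)
    squares = nums.map(lambda n: n * n)
    print(nums.zip(squares).collect())
    # [(0, 0), (1, 1), (2, 4), (3, 9), (4, 16)]
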
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 8c6ad79059..12c63f186a 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -204,7 +204,7 @@ class CartesianDeserializer(FramedSerializer):
         self.key_ser = key_ser
         self.val_ser = val_ser
 
-    def load_stream(self, stream):
+    def prepare_keys_values(self, stream):
         key_stream = self.key_ser._load_stream_without_unbatching(stream)
         val_stream = self.val_ser._load_stream_without_unbatching(stream)
         key_is_batched = isinstance(self.key_ser, BatchedSerializer)
@@ -212,6 +212,10 @@ class CartesianDeserializer(FramedSerializer):
         for (keys, vals) in izip(key_stream, val_stream):
             keys = keys if key_is_batched else [keys]
             vals = vals if val_is_batched else [vals]
+            yield (keys, vals)
+
+    def load_stream(self, stream):
+        for (keys, vals) in self.prepare_keys_values(stream):
             for pair in product(keys, vals):
                 yield pair
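
The refactor above moves the batch-unpacking loop into prepare_keys_values() so that a subclass can reuse it, leaving only the cross-product step in CartesianDeserializer.load_stream(). A standalone sketch of that control flow (plain Python 2, matching the izip/product imports already used in this file; the hard-coded batch streams are hypothetical stand-ins for the framed key/value streams):

    from itertools import izip, product

    # Stand-ins for the unbatched key/value streams: each element is a
    # batch (list) of already-deserialized objects.
    key_stream = iter([[0, 1], [2, 3]])
    val_stream = iter([['a', 'b'], ['c', 'd']])

    def prepare_keys_values():
        for keys, vals in izip(key_stream, val_stream):
            yield (keys, vals)

    # Cartesian-style load_stream: cross product within each batch pair.
    for keys, vals in prepare_keys_values():
        for pair in product(keys, vals):
            print(pair)  # (0, 'a'), (0, 'b'), (1, 'a'), (1, 'b'), ...
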
@@ -224,6 +228,29 @@ class CartesianDeserializer(FramedSerializer):
                (str(self.key_ser), str(self.val_ser))
+class PairDeserializer(CartesianDeserializer):
+    """
+    Deserializes the JavaRDD zip() of two PythonRDDs.
+    """
+
+    def __init__(self, key_ser, val_ser):
+        self.key_ser = key_ser
+        self.val_ser = val_ser
+
+    def load_stream(self, stream):
+        for (keys, vals) in self.prepare_keys_values(stream):
+            for pair in izip(keys, vals):
+                yield pair
+
+    def __eq__(self, other):
+        return isinstance(other, PairDeserializer) and \
+            self.key_ser == other.key_ser and self.val_ser == other.val_ser
+
+    def __str__(self):
+        return "PairDeserializer<%s, %s>" % \
+               (str(self.key_ser), str(self.val_ser))
+
+
 class NoOpSerializer(FramedSerializer):
 
     def loads(self, obj): return obj
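
The only behavioral difference from CartesianDeserializer is the inner loop: PairDeserializer pairs elements positionally with izip() instead of taking the cross product, which matches what the underlying JavaRDD zip() produces. A minimal side-by-side illustration of the two semantics (plain Python 2, hypothetical data):

    from itertools import izip, product

    keys = [0, 1, 2]
    vals = ['a', 'b', 'c']

    # Cartesian semantics: every key with every value -> 9 pairs.
    print(list(product(keys, vals)))

    # zip semantics: i-th key with i-th value -> 3 pairs.
    print(list(izip(keys, vals)))  # [(0, 'a'), (1, 'b'), (2, 'c')]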