aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/join.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/join.py')
-rw-r--r--python/pyspark/join.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index 6f94d26ef8..5f3a7e71f7 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -79,15 +79,15 @@ def python_left_outer_join(rdd, other, numPartitions):
return _do_python_join(rdd, other, numPartitions, dispatch)
-def python_cogroup(rdd, other, numPartitions):
- vs = rdd.map(lambda (k, v): (k, (1, v)))
- ws = other.map(lambda (k, v): (k, (2, v)))
+def python_cogroup(rdds, numPartitions):
+ def make_mapper(i):
+ return lambda (k, v): (k, (i, v))
+ vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)]
+ union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
+ rdd_len = len(vrdds)
def dispatch(seq):
- vbuf, wbuf = [], []
+ bufs = [[] for i in range(rdd_len)]
for (n, v) in seq:
- if n == 1:
- vbuf.append(v)
- elif n == 2:
- wbuf.append(v)
- return (ResultIterable(vbuf), ResultIterable(wbuf))
- return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch)
+ bufs[n].append(v)
+ return tuple(map(ResultIterable, bufs))
+ return union_vrdds.groupByKey(numPartitions).mapValues(dispatch)