diff options
Diffstat (limited to 'python/pyspark/join.py')
-rw-r--r-- | python/pyspark/join.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/python/pyspark/join.py b/python/pyspark/join.py index 6f94d26ef8..5f3a7e71f7 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -79,15 +79,15 @@ def python_left_outer_join(rdd, other, numPartitions): return _do_python_join(rdd, other, numPartitions, dispatch) -def python_cogroup(rdd, other, numPartitions): - vs = rdd.map(lambda (k, v): (k, (1, v))) - ws = other.map(lambda (k, v): (k, (2, v))) +def python_cogroup(rdds, numPartitions): + def make_mapper(i): + return lambda (k, v): (k, (i, v)) + vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)] + union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds) + rdd_len = len(vrdds) def dispatch(seq): - vbuf, wbuf = [], [] + bufs = [[] for i in range(rdd_len)] for (n, v) in seq: - if n == 1: - vbuf.append(v) - elif n == 2: - wbuf.append(v) - return (ResultIterable(vbuf), ResultIterable(wbuf)) - return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch) + bufs[n].append(v) + return tuple(map(ResultIterable, bufs)) + return union_vrdds.groupByKey(numPartitions).mapValues(dispatch) |