author     Allan Douglas R. de Oliveira <allandouglas@gmail.com>  2014-06-20 11:03:03 -0700
committer  Patrick Wendell <pwendell@gmail.com>  2014-06-20 11:03:03 -0700
commit     6a224c31e8563156ad5732a23667e73076984ae1 (patch)
tree       e3364e0ab07258a483668635a90da442a6d0a8df
parent     d484ddeff1440d8e14e05c3cd7e7a18746f1a586 (diff)
download   spark-6a224c31e8563156ad5732a23667e73076984ae1.tar.gz
           spark-6a224c31e8563156ad5732a23667e73076984ae1.tar.bz2
           spark-6a224c31e8563156ad5732a23667e73076984ae1.zip
SPARK-1868: Users should be allowed to cogroup at least 4 RDDs
Adds cogroup for 4 RDDs.

Author: Allan Douglas R. de Oliveira <allandouglas@gmail.com>

Closes #813 from douglaz/more_cogroups and squashes the following commits:

f8d6273 [Allan Douglas R. de Oliveira] Test python groupWith for one more case
0e9009c [Allan Douglas R. de Oliveira] Added scala tests
c3ffcdd [Allan Douglas R. de Oliveira] Added java tests
517a67f [Allan Douglas R. de Oliveira] Added tests for python groupWith
2f402d5 [Allan Douglas R. de Oliveira] Removed TODO
17474f4 [Allan Douglas R. de Oliveira] Use new cogroup function
7877a2a [Allan Douglas R. de Oliveira] Fixed code
ba02414 [Allan Douglas R. de Oliveira] Added varargs cogroup to pyspark
c4a8a51 [Allan Douglas R. de Oliveira] Added java cogroup 4
e94963c [Allan Douglas R. de Oliveira] Fixed spacing
f1ee57b [Allan Douglas R. de Oliveira] Fixed scala style issues
d7196f1 [Allan Douglas R. de Oliveira] Allow the cogroup of 4 RDDs
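For readers skimming the patch, a minimal usage sketch of the new 4-way cogroup (illustrative only, not part of the commit; the RDD contents and an in-scope SparkContext `sc` are assumed):

    // Four pair RDDs sharing some keys
    val a = sc.parallelize(Seq((1, "a1"), (2, "a2")))
    val b = sc.parallelize(Seq((1, "b1")))
    val c = sc.parallelize(Seq((2, "c1")))
    val d = sc.parallelize(Seq((1, "d1"), (2, "d2")))

    // cogroup of 4 RDDs: one Iterable per input RDD, empty where a key is absent
    val grouped: RDD[(Int, (Iterable[String], Iterable[String], Iterable[String], Iterable[String]))] =
      a.cogroup(b, c, d)
    // grouped.collect() would contain, up to ordering and the concrete Iterable type:
    //   (1, (Seq("a1"), Seq("b1"), Seq(),     Seq("d1")))
    //   (2, (Seq("a2"), Seq(),     Seq("c1"), Seq("d2")))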
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala      | 51
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala      | 51
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java                | 63
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala | 33
-rw-r--r--  python/pyspark/join.py                                               | 20
-rw-r--r--  python/pyspark/rdd.py                                                | 22
6 files changed, 223 insertions(+), 17 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 14fa9d8135..4f3081433a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -544,6 +544,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner)))
/**
+ * For each key k in `this` or `other1` or `other2` or `other3`,
+ * return a resulting RDD that contains a tuple with the list of values
+ * for that key in `this`, `other1`, `other2` and `other3`.
+ */
+ def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
+ other2: JavaPairRDD[K, W2],
+ other3: JavaPairRDD[K, W3],
+ partitioner: Partitioner)
+ : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
+ fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, partitioner)))
+
+ /**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
@@ -559,6 +571,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2)))
/**
+ * For each key k in `this` or `other1` or `other2` or `other3`,
+ * return a resulting RDD that contains a tuple with the list of values
+ * for that key in `this`, `other1`, `other2` and `other3`.
+ */
+ def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
+ other2: JavaPairRDD[K, W2],
+ other3: JavaPairRDD[K, W3])
+ : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
+ fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3)))
+
+ /**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
@@ -574,6 +597,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions)))
+ /**
+ * For each key k in `this` or `other1` or `other2` or `other3`,
+ * return a resulting RDD that contains a tuple with the list of values
+ * for that key in `this`, `other1`, `other2` and `other3`.
+ */
+ def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
+ other2: JavaPairRDD[K, W2],
+ other3: JavaPairRDD[K, W3],
+ numPartitions: Int)
+ : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
+ fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, numPartitions)))
+
/** Alias for cogroup. */
def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] =
fromRDD(cogroupResultToJava(rdd.groupWith(other)))
@@ -583,6 +618,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2)))
+ /** Alias for cogroup. */
+ def groupWith[W1, W2, W3](other1: JavaPairRDD[K, W1],
+ other2: JavaPairRDD[K, W2],
+ other3: JavaPairRDD[K, W3])
+ : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
+ fromRDD(cogroupResult3ToJava(rdd.groupWith(other1, other2, other3)))
+
/**
* Return the list of values in the RDD for key `key`. This operation is done efficiently if the
* RDD has a known partitioner by only searching the partition that the key maps to.
@@ -786,6 +828,15 @@ object JavaPairRDD {
.mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3)))
}
+ private[spark]
+ def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3](
+ rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))])
+ : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = {
+ rddToPairRDDFunctions(rdd)
+ .mapValues(x =>
+ (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4)))
+ }
+
def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = {
new JavaPairRDD[K, V](rdd)
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index fe36c80e0b..443d1c587c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -568,6 +568,28 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
/**
+ * For each key k in `this` or `other1` or `other2` or `other3`,
+ * return a resulting RDD that contains a tuple with the list of values
+ * for that key in `this`, `other1`, `other2` and `other3`.
+ */
+ def cogroup[W1, W2, W3](other1: RDD[(K, W1)],
+ other2: RDD[(K, W2)],
+ other3: RDD[(K, W3)],
+ partitioner: Partitioner)
+ : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
+ if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
+ val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner)
+ cg.mapValues { case Seq(vs, w1s, w2s, w3s) =>
+ (vs.asInstanceOf[Seq[V]],
+ w1s.asInstanceOf[Seq[W1]],
+ w2s.asInstanceOf[Seq[W2]],
+ w3s.asInstanceOf[Seq[W3]])
+ }
+ }
+
+ /**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
@@ -600,6 +622,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
/**
+ * For each key k in `this` or `other1` or `other2` or `other3`,
+ * return a resulting RDD that contains a tuple with the list of values
+ * for that key in `this`, `other1`, `other2` and `other3`.
+ */
+ def cogroup[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)])
+ : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
+ cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3))
+ }
+
+ /**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
@@ -633,6 +665,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
cogroup(other1, other2, new HashPartitioner(numPartitions))
}
+ /**
+ * For each key k in `this` or `other1` or `other2` or `other3`,
+ * return a resulting RDD that contains a tuple with the list of values
+ * for that key in `this`, `other1`, `other2` and `other3`.
+ */
+ def cogroup[W1, W2, W3](other1: RDD[(K, W1)],
+ other2: RDD[(K, W2)],
+ other3: RDD[(K, W3)],
+ numPartitions: Int)
+ : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
+ cogroup(other1, other2, other3, new HashPartitioner(numPartitions))
+ }
+
/** Alias for cogroup. */
def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = {
cogroup(other, defaultPartitioner(self, other))
@@ -644,6 +689,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
cogroup(other1, other2, defaultPartitioner(self, other1, other2))
}
+ /** Alias for cogroup. */
+ def groupWith[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)])
+ : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
+ cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3))
+ }
+
/**
* Return an RDD with the pairs from `this` whose keys are not in `other`.
*
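As the hunks above show, all three new Scala overloads funnel into the partitioner-taking variant, which builds a CoGroupedRDD over the four inputs; the convenience forms merely choose the partitioner, and hash-partitioning array keys still throws SparkException. A quick sketch of the call sites (hypothetical RDDs, assuming `sc` and the usual `org.apache.spark.HashPartitioner` import; not part of the patch):

    val a = sc.parallelize(Seq(("k", 1)))
    val b = sc.parallelize(Seq(("k", 2.0)))
    val c = sc.parallelize(Seq(("k", "three")))
    val d = sc.parallelize(Seq(("k", 4L)))

    a.cogroup(b, c, d)                          // defaultPartitioner(a, b, c, d)
    a.cogroup(b, c, d, 8)                       // new HashPartitioner(8)
    a.cogroup(b, c, d, new HashPartitioner(8))  // explicit partitioner
    a.groupWith(b, c, d)                        // alias for the default-partitioner form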
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e46298c6a9..761f2d6a77 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -21,6 +21,9 @@ import java.io.*;
import java.util.*;
import scala.Tuple2;
+import scala.Tuple3;
+import scala.Tuple4;
+
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
@@ -306,6 +309,66 @@ public class JavaAPISuite implements Serializable {
@SuppressWarnings("unchecked")
@Test
+ public void cogroup3() {
+ JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, String>("Apples", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Citrus")
+ ));
+ JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, Integer>("Oranges", 2),
+ new Tuple2<String, Integer>("Apples", 3)
+ ));
+ JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, Integer>("Oranges", 21),
+ new Tuple2<String, Integer>("Apples", 42)
+ ));
+
+ JavaPairRDD<String, Tuple3<Iterable<String>, Iterable<Integer>, Iterable<Integer>>> cogrouped =
+ categories.cogroup(prices, quantities);
+ Assert.assertEquals("[Fruit, Citrus]",
+ Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
+ Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
+ Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
+
+
+ cogrouped.collect();
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void cogroup4() {
+ JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, String>("Apples", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Citrus")
+ ));
+ JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, Integer>("Oranges", 2),
+ new Tuple2<String, Integer>("Apples", 3)
+ ));
+ JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, Integer>("Oranges", 21),
+ new Tuple2<String, Integer>("Apples", 42)
+ ));
+ JavaPairRDD<String, String> countries = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, String>("Oranges", "BR"),
+ new Tuple2<String, String>("Apples", "US")
+ ));
+
+ JavaPairRDD<String, Tuple4<Iterable<String>, Iterable<Integer>, Iterable<Integer>, Iterable<String>>> cogrouped =
+ categories.cogroup(prices, quantities, countries);
+ Assert.assertEquals("[Fruit, Citrus]",
+ Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
+ Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
+ Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
+ Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4()));
+
+ cogrouped.collect();
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
public void leftOuterJoin() {
JavaPairRDD<Integer, Integer> rdd1 = sc.parallelizePairs(Arrays.asList(
new Tuple2<Integer, Integer>(1, 1),
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 0b9004448a..447e38ec9d 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -249,6 +249,39 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
))
}
+ test("groupWith3") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd')))
+ val joined = rdd1.groupWith(rdd2, rdd3).collect()
+ assert(joined.size === 4)
+ val joinedSet = joined.map(x => (x._1,
+ (x._2._1.toList, x._2._2.toList, x._2._3.toList))).toSet
+ assert(joinedSet === Set(
+ (1, (List(1, 2), List('x'), List('a'))),
+ (2, (List(1), List('y', 'z'), List())),
+ (3, (List(1), List(), List('b'))),
+ (4, (List(), List('w'), List('c', 'd')))
+ ))
+ }
+
+ test("groupWith4") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd')))
+ val rdd4 = sc.parallelize(Array((2, '@')))
+ val joined = rdd1.groupWith(rdd2, rdd3, rdd4).collect()
+ assert(joined.size === 4)
+ val joinedSet = joined.map(x => (x._1,
+ (x._2._1.toList, x._2._2.toList, x._2._3.toList, x._2._4.toList))).toSet
+ assert(joinedSet === Set(
+ (1, (List(1, 2), List('x'), List('a'), List())),
+ (2, (List(1), List('y', 'z'), List(), List('@'))),
+ (3, (List(1), List(), List('b'), List())),
+ (4, (List(), List('w'), List('c', 'd'), List()))
+ ))
+ }
+
test("zero-partition RDD") {
val emptyDir = Files.createTempDir()
emptyDir.deleteOnExit()
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index 6f94d26ef8..5f3a7e71f7 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -79,15 +79,15 @@ def python_left_outer_join(rdd, other, numPartitions):
return _do_python_join(rdd, other, numPartitions, dispatch)
-def python_cogroup(rdd, other, numPartitions):
- vs = rdd.map(lambda (k, v): (k, (1, v)))
- ws = other.map(lambda (k, v): (k, (2, v)))
+def python_cogroup(rdds, numPartitions):
+ def make_mapper(i):
+ return lambda (k, v): (k, (i, v))
+ vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)]
+ union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
+ rdd_len = len(vrdds)
def dispatch(seq):
- vbuf, wbuf = [], []
+ bufs = [[] for i in range(rdd_len)]
for (n, v) in seq:
- if n == 1:
- vbuf.append(v)
- elif n == 2:
- wbuf.append(v)
- return (ResultIterable(vbuf), ResultIterable(wbuf))
- return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch)
+ bufs[n].append(v)
+ return tuple(map(ResultIterable, bufs))
+ return union_vrdds.groupByKey(numPartitions).mapValues(dispatch)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 62a95c8467..1d55c35a8b 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1233,7 +1233,7 @@ class RDD(object):
combiners[k] = mergeCombiners(combiners[k], v)
return combiners.iteritems()
return shuffled.mapPartitions(_mergeCombiners)
-
+
def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
"""
Aggregate the values of each key, using given combine functions and a neutral "zero value".
@@ -1245,7 +1245,7 @@ class RDD(object):
"""
def createZero():
return copy.deepcopy(zeroValue)
-
+
return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions)
def foldByKey(self, zeroValue, func, numPartitions=None):
@@ -1323,12 +1323,20 @@ class RDD(object):
map_values_fn = lambda (k, v): (k, f(v))
return self.map(map_values_fn, preservesPartitioning=True)
- # TODO: support varargs cogroup of several RDDs.
- def groupWith(self, other):
+ def groupWith(self, other, *others):
"""
- Alias for cogroup.
+ Alias for cogroup but with support for multiple RDDs.
+
+ >>> w = sc.parallelize([("a", 5), ("b", 6)])
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2)])
+ >>> z = sc.parallelize([("b", 42)])
+ >>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \
+ sorted(list(w.groupWith(x, y, z).collect())))
+ [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))]
+
"""
- return self.cogroup(other)
+ return python_cogroup((self, other) + others, numPartitions=None)
# TODO: add variant with custom partitioner
def cogroup(self, other, numPartitions=None):
@@ -1342,7 +1350,7 @@ class RDD(object):
>>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect())))
[('a', ([1], [2])), ('b', ([4], []))]
"""
- return python_cogroup(self, other, numPartitions)
+ return python_cogroup((self, other), numPartitions)
def subtractByKey(self, other, numPartitions=None):
"""