author      Sandy Ryza <sandy@cloudera.com>        2014-06-12 08:14:25 -0700
committer   Patrick Wendell <pwendell@gmail.com>   2014-06-12 08:14:25 -0700
commit      ce92a9c18f033ac9fa2f12143fab00a90e0f4577 (patch)
tree        34507639ff2f876630c6868619844c52e13d3720 /core
parent      43d53d51c9ee2626d9de91faa3b192979b86821d (diff)
download    spark-ce92a9c18f033ac9fa2f12143fab00a90e0f4577.tar.gz
            spark-ce92a9c18f033ac9fa2f12143fab00a90e0f4577.tar.bz2
            spark-ce92a9c18f033ac9fa2f12143fab00a90e0f4577.zip
SPARK-554. Add aggregateByKey.
Author: Sandy Ryza <sandy@cloudera.com>

Closes #705 from sryza/sandy-spark-554 and squashes the following commits:

2302b8f [Sandy Ryza] Add MIMA exclude
f52e0ad [Sandy Ryza] Fix Python tests for real
2f3afa3 [Sandy Ryza] Fix Python test
0b735e9 [Sandy Ryza] Fix line lengths
ae56746 [Sandy Ryza] Fix doc (replace T with V)
c2be415 [Sandy Ryza] Java and Python aggregateByKey
23bf400 [Sandy Ryza] SPARK-554. Add aggregateByKey.
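As a quick illustration of what the patch adds (a sketch written for this summary, not part of the change itself), aggregateByKey is useful precisely when the aggregated result type U differs from the value type V, which foldByKey cannot express. The snippet below computes a per-key mean by folding Double values into (sum, count) pairs; the object name, the local master, and the sample data are invented for the example.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._   // brings the PairRDDFunctions implicits into scope

// Sketch only: per-key mean via aggregateByKey, with U = (Double, Long) and V = Double.
object AggregateByKeyMeanSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setAppName("aggregateByKey-sketch").setMaster("local[2]"))
    val ratings = sc.parallelize(Seq(("a", 1.0), ("a", 3.0), ("b", 4.0)), 2)

    val sumCounts = ratings.aggregateByKey((0.0, 0L))(
      (acc, v) => (acc._1 + v, acc._2 + 1),     // seqOp: merge a value into (sum, count) within a partition
      (x, y) => (x._1 + y._1, x._2 + y._2))     // combOp: merge (sum, count) pairs across partitions

    sumCounts.mapValues { case (sum, n) => sum / n }.collect().foreach(println)  // (a,2.0), (b,4.0)
    sc.stop()
  }
}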
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala       | 44
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala       | 50
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java                 | 31
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala  | 13
4 files changed, 138 insertions(+), 0 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 7dcfbf741c..14fa9d8135 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -229,6 +229,50 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
/**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U](zeroValue: U, partitioner: Partitioner, seqFunc: JFunction2[U, V, U],
+ combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = {
+ implicit val ctag: ClassTag[U] = fakeClassTag
+ fromRDD(rdd.aggregateByKey(zeroValue, partitioner)(seqFunc, combFunc))
+ }
+
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U](zeroValue: U, numPartitions: Int, seqFunc: JFunction2[U, V, U],
+ combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = {
+ implicit val ctag: ClassTag[U] = fakeClassTag
+ fromRDD(rdd.aggregateByKey(zeroValue, numPartitions)(seqFunc, combFunc))
+ }
+
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's.
+ * The former operation is used for merging values within a partition, and the latter is used for
+ * merging values between partitions. To avoid memory allocation, both of these functions are
+ * allowed to modify and return their first argument instead of creating a new U.
+ */
+ def aggregateByKey[U](zeroValue: U, seqFunc: JFunction2[U, V, U], combFunc: JFunction2[U, U, U]):
+ JavaPairRDD[K, U] = {
+ implicit val ctag: ClassTag[U] = fakeClassTag
+ fromRDD(rdd.aggregateByKey(zeroValue)(seqFunc, combFunc))
+ }
+
+ /**
* Merge the values for each key using an associative function and a neutral "zero value" which
* may be added to the result an arbitrary number of times, and must not change the result
* (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.).
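The three Java overloads above differ only in how partitioning is chosen: an explicit Partitioner, a partition count, or the default. The doc comment's point that seqFunc and combFunc may modify and return their first argument applies to the Scala API in the same way; below is a small sketch (illustration only, with a made-up words RDD and an existing SparkContext named sc assumed) that builds a per-key word histogram by mutating the accumulator map in place instead of allocating a new one on every call.

import scala.collection.mutable
import org.apache.spark.SparkContext._

// Sketch only: mutate and return the first argument to avoid per-record allocation.
val words = sc.parallelize(Seq((1, "a"), (1, "b"), (1, "a"), (2, "c")), 2)
val histograms = words.aggregateByKey(mutable.Map.empty[String, Int])(
  (acc, w) => { acc(w) = acc.getOrElse(w, 0) + 1; acc },                             // merge a value within a partition
  (m1, m2) => { m2.foreach { case (w, n) => m1(w) = m1.getOrElse(w, 0) + n }; m1 })  // merge maps across partitions
// histograms.collect() would yield, e.g., (1, Map(a -> 2, b -> 1)) and (2, Map(c -> 1))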
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 8909980957..b6ad9b6c3e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -119,6 +119,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
/**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U,
+ combOp: (U, U) => U): RDD[(K, U)] = {
+ // Serialize the zero value to a byte array so that we can get a new clone of it on each key
+ val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+ val zeroArray = new Array[Byte](zeroBuffer.limit)
+ zeroBuffer.get(zeroArray)
+
+ lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+ def createZero() = cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray))
+
+ combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner)
+ }
+
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U: ClassTag](zeroValue: U, numPartitions: Int)(seqOp: (U, V) => U,
+ combOp: (U, U) => U): RDD[(K, U)] = {
+ aggregateByKey(zeroValue, new HashPartitioner(numPartitions))(seqOp, combOp)
+ }
+
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U: ClassTag](zeroValue: U)(seqOp: (U, V) => U,
+ combOp: (U, U) => U): RDD[(K, U)] = {
+ aggregateByKey(zeroValue, defaultPartitioner(self))(seqOp, combOp)
+ }
+
+ /**
* Merge the values for each key using an associative function and a neutral "zero value" which
* may be added to the result an arbitrary number of times, and must not change the result
* (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.).
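A detail worth noting in the Scala implementation above: the zero value is serialized once with the closure serializer and a fresh copy is deserialized for each key (see createZero), and the operation is then expressed as a combineByKey call whose createCombiner seeds a new accumulator from that copy. This makes mutable zero values safe, since no two keys ever share the same accumulator object. A minimal sketch of that usage (illustration only, assuming an existing SparkContext named sc), in the same style as the new PairRDDFunctionsSuite test further down:

import scala.collection.mutable.HashSet
import org.apache.spark.SparkContext._

// Sketch only: a mutable HashSet as the zero value. Each key starts from its own
// deserialized clone of the empty set, so the in-place += / ++= updates cannot
// leak values between keys.
val pairs = sc.parallelize(Seq((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2)
val perKeySets = pairs.aggregateByKey(new HashSet[Int]())(_ += _, _ ++= _)
// perKeySets.collect() would yield, e.g., (1, Set(1)), (3, Set(2)), (5, Set(1, 3))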
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 50a6212911..ef41bfb88d 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -317,6 +317,37 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals(33, sum);
}
+ @Test
+ public void aggregateByKey() {
+ JavaPairRDD<Integer, Integer> pairs = sc.parallelizePairs(
+ Arrays.asList(
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(3, 2),
+ new Tuple2<Integer, Integer>(5, 1),
+ new Tuple2<Integer, Integer>(5, 3)), 2);
+
+ Map<Integer, Set<Integer>> sets = pairs.aggregateByKey(new HashSet<Integer>(),
+ new Function2<Set<Integer>, Integer, Set<Integer>>() {
+ @Override
+ public Set<Integer> call(Set<Integer> a, Integer b) {
+ a.add(b);
+ return a;
+ }
+ },
+ new Function2<Set<Integer>, Set<Integer>, Set<Integer>>() {
+ @Override
+ public Set<Integer> call(Set<Integer> a, Set<Integer> b) {
+ a.addAll(b);
+ return a;
+ }
+ }).collectAsMap();
+ Assert.assertEquals(3, sets.size());
+ Assert.assertEquals(new HashSet<Integer>(Arrays.asList(1)), sets.get(1));
+ Assert.assertEquals(new HashSet<Integer>(Arrays.asList(2)), sets.get(3));
+ Assert.assertEquals(new HashSet<Integer>(Arrays.asList(1, 3)), sets.get(5));
+ }
+
@SuppressWarnings("unchecked")
@Test
public void foldByKey() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 9ddafc4518..0b9004448a 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -30,6 +30,19 @@ import org.apache.spark.SparkContext._
import org.apache.spark.{Partitioner, SharedSparkContext}
class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+ test("aggregateByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2)
+
+ val sets = pairs.aggregateByKey(new HashSet[Int]())(_ += _, _ ++= _).collect()
+ assert(sets.size === 3)
+ val valuesFor1 = sets.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1))
+ val valuesFor3 = sets.find(_._1 == 3).get._2
+ assert(valuesFor3.toList.sorted === List(2))
+ val valuesFor5 = sets.find(_._1 == 5).get._2
+ assert(valuesFor5.toList.sorted === List(1, 3))
+ }
+
test("groupByKey") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey().collect()