-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala        44
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala        50
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java                  31
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala   13
-rw-r--r--  docs/programming-guide.md                                                4
-rw-r--r--  project/MimaExcludes.scala                                               5
-rw-r--r--  python/pyspark/rdd.py                                                   19
-rw-r--r--  python/pyspark/tests.py                                                 15
8 files changed, 179 insertions(+), 2 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 7dcfbf741c..14fa9d8135 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -229,6 +229,50 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
/**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U](zeroValue: U, partitioner: Partitioner, seqFunc: JFunction2[U, V, U],
+ combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = {
+ implicit val ctag: ClassTag[U] = fakeClassTag
+ fromRDD(rdd.aggregateByKey(zeroValue, partitioner)(seqFunc, combFunc))
+ }
+
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U](zeroValue: U, numPartitions: Int, seqFunc: JFunction2[U, V, U],
+ combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = {
+ implicit val ctag: ClassTag[U] = fakeClassTag
+ fromRDD(rdd.aggregateByKey(zeroValue, numPartitions)(seqFunc, combFunc))
+ }
+
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's.
+ * The former operation is used for merging values within a partition, and the latter is used for
+ * merging values between partitions. To avoid memory allocation, both of these functions are
+ * allowed to modify and return their first argument instead of creating a new U.
+ */
+ def aggregateByKey[U](zeroValue: U, seqFunc: JFunction2[U, V, U], combFunc: JFunction2[U, U, U]):
+ JavaPairRDD[K, U] = {
+ implicit val ctag: ClassTag[U] = fakeClassTag
+ fromRDD(rdd.aggregateByKey(zeroValue)(seqFunc, combFunc))
+ }
+
+ /**
* Merge the values for each key using an associative function and a neutral "zero value" which
* may be added to the result an arbitrary number of times, and must not change the result
* (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.).
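
For illustration only, a minimal usage sketch of the contract the scaladoc above describes, written against the Scala API that these Java overloads delegate to. It assumes an existing SparkContext `sc`, `import org.apache.spark.SparkContext._`, and made-up data: the accumulator type U = (Int, Int) differs from the value type V = Int, seqOp merges a value into the per-key accumulator within a partition, and combOp merges partial accumulators across partitions.

// Hedged sketch: per-key averages with an accumulator type different from the value type.
val scores = sc.parallelize(Seq(("a", 3), ("a", 5), ("b", 4)))
val sumCount = scores.aggregateByKey((0, 0))(
  (acc, v) => (acc._1 + v, acc._2 + 1),      // seqOp: merge a V (Int) into a U ((Int, Int))
  (a, b) => (a._1 + b._1, a._2 + b._2))      // combOp: merge two U's across partitions
val averages = sumCount.mapValues { case (sum, count) => sum.toDouble / count }
// averages.collect() => Array(("a", 4.0), ("b", 4.0)), order may vary
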
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 8909980957..b6ad9b6c3e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -119,6 +119,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
/**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U,
+ combOp: (U, U) => U): RDD[(K, U)] = {
+ // Serialize the zero value to a byte array so that we can get a new clone of it on each key
+ val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+ val zeroArray = new Array[Byte](zeroBuffer.limit)
+ zeroBuffer.get(zeroArray)
+
+ lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+ def createZero() = cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray))
+
+ combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner)
+ }
+
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U: ClassTag](zeroValue: U, numPartitions: Int)(seqOp: (U, V) => U,
+ combOp: (U, U) => U): RDD[(K, U)] = {
+ aggregateByKey(zeroValue, new HashPartitioner(numPartitions))(seqOp, combOp)
+ }
+
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U: ClassTag](zeroValue: U)(seqOp: (U, V) => U,
+ combOp: (U, U) => U): RDD[(K, U)] = {
+ aggregateByKey(zeroValue, defaultPartitioner(self))(seqOp, combOp)
+ }
+
+ /**
* Merge the values for each key using an associative function and a neutral "zero value" which
* may be added to the result an arbitrary number of times, and must not change the result
* (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.).
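
The serialize-then-deserialize dance around zeroValue above exists because seqOp is explicitly allowed to modify and return its first argument: a single shared mutable zero (such as the HashSet used in the tests below) would be mutated across every key in a partition and corrupt the result, so a fresh clone is produced per key. A plain-Scala sketch of the pitfall, with a stand-in createZero() playing the role of the closure-serializer clone in the code above:

import scala.collection.mutable

def seqOp(acc: mutable.HashSet[Int], v: Int) = { acc += v; acc }

// Sharing one zero instance across keys corrupts the result once seqOp mutates it:
val sharedZero = mutable.HashSet.empty[Int]
seqOp(sharedZero, 1)   // sharedZero is now {1}
seqOp(sharedZero, 2)   // sharedZero is now {1, 2}: the second key "sees" the first key's value

// A fresh zero per key (what createZero() provides via the closure serializer) avoids this:
def createZero() = mutable.HashSet.empty[Int]
val forKey1 = seqOp(createZero(), 1)   // {1}
val forKey3 = seqOp(createZero(), 2)   // {2}
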
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 50a6212911..ef41bfb88d 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -317,6 +317,37 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals(33, sum);
}
+ @Test
+ public void aggregateByKey() {
+ JavaPairRDD<Integer, Integer> pairs = sc.parallelizePairs(
+ Arrays.asList(
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(3, 2),
+ new Tuple2<Integer, Integer>(5, 1),
+ new Tuple2<Integer, Integer>(5, 3)), 2);
+
+ Map<Integer, Set<Integer>> sets = pairs.aggregateByKey(new HashSet<Integer>(),
+ new Function2<Set<Integer>, Integer, Set<Integer>>() {
+ @Override
+ public Set<Integer> call(Set<Integer> a, Integer b) {
+ a.add(b);
+ return a;
+ }
+ },
+ new Function2<Set<Integer>, Set<Integer>, Set<Integer>>() {
+ @Override
+ public Set<Integer> call(Set<Integer> a, Set<Integer> b) {
+ a.addAll(b);
+ return a;
+ }
+ }).collectAsMap();
+ Assert.assertEquals(3, sets.size());
+ Assert.assertEquals(new HashSet<Integer>(Arrays.asList(1)), sets.get(1));
+ Assert.assertEquals(new HashSet<Integer>(Arrays.asList(2)), sets.get(3));
+ Assert.assertEquals(new HashSet<Integer>(Arrays.asList(1, 3)), sets.get(5));
+ }
+
@SuppressWarnings("unchecked")
@Test
public void foldByKey() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 9ddafc4518..0b9004448a 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -30,6 +30,19 @@ import org.apache.spark.SparkContext._
import org.apache.spark.{Partitioner, SharedSparkContext}
class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+ test("aggregateByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2)
+
+ val sets = pairs.aggregateByKey(new HashSet[Int]())(_ += _, _ ++= _).collect()
+ assert(sets.size === 3)
+ val valuesFor1 = sets.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1))
+ val valuesFor3 = sets.find(_._1 == 3).get._2
+ assert(valuesFor3.toList.sorted === List(2))
+ val valuesFor5 = sets.find(_._1 == 5).get._2
+ assert(valuesFor5.toList.sorted === List(1, 3))
+ }
+
test("groupByKey") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey().collect()
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 7989e02dfb..79784682bf 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -891,6 +891,10 @@ for details.
<td> When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
</tr>
<tr>
+ <td> <b>aggregateByKey</b>(<i>zeroValue</i>)(<i>seqOp</i>, <i>combOp</i>, [<i>numTasks</i>]) </td>
+ <td> When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
+</tr>
+<tr>
<td> <b>sortByKey</b>([<i>ascending</i>], [<i>numTasks</i>]) </td>
<td> When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean <code>ascending</code> argument.</td>
</tr>
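
The optional numTasks argument in the table entry above corresponds to the aggregateByKey(zeroValue, numPartitions)(seqOp, combOp) overload added in PairRDDFunctions. A minimal hedged sketch with made-up data, assuming an existing SparkContext `sc` and `import org.apache.spark.SparkContext._`:

// Per-key sums with an explicit number of reduce tasks.
val counts = sc.parallelize(Seq(("x", 1), ("y", 2), ("x", 3)))
  .aggregateByKey(0, numPartitions = 4)(_ + _, _ + _)
// counts.partitions.length == 4; counts.collect() => Array(("x", 4), ("y", 2)), order may vary
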
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index c80ab9a9f8..ee629794f6 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -52,7 +52,10 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"),
ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1")
+ "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$"
+ + "createZero$1")
) ++
Seq( // Ignore some private methods in ALS.
ProblemFilters.exclude[MissingMethodProblem](
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 8a215fc511..735389c698 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1178,6 +1178,20 @@ class RDD(object):
combiners[k] = mergeCombiners(combiners[k], v)
return combiners.iteritems()
return shuffled.mapPartitions(_mergeCombiners)
+
+ def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
+ """
+ Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ This function can return a different result type, U, than the type of the values in this RDD,
+ V. Thus, we need one operation for merging a V into a U and one operation for merging two U's.
+ The former operation is used for merging values within a partition, and the latter is used
+ for merging values between partitions. To avoid memory allocation, both of these functions are
+ allowed to modify and return their first argument instead of creating a new U.
+ """
+ def createZero():
+ return copy.deepcopy(zeroValue)
+
+ return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions)
def foldByKey(self, zeroValue, func, numPartitions=None):
"""
@@ -1190,7 +1204,10 @@ class RDD(object):
>>> rdd.foldByKey(0, add).collect()
[('a', 2), ('b', 1)]
"""
- return self.combineByKey(lambda v: func(zeroValue, v), func, func, numPartitions)
+ def createZero():
+ return copy.deepcopy(zeroValue)
+
+ return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)
# TODO: support variant with custom partitioner
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 184ee810b8..c15bb45775 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -188,6 +188,21 @@ class TestRDDFunctions(PySparkTestCase):
os.unlink(tempFile.name)
self.assertRaises(Exception, lambda: filtered_data.count())
+ def testAggregateByKey(self):
+ data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
+ def seqOp(x, y):
+ x.add(y)
+ return x
+
+ def combOp(x, y):
+ x |= y
+ return x
+
+ sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect())
+ self.assertEqual(3, len(sets))
+ self.assertEqual(set([1]), sets[1])
+ self.assertEqual(set([2]), sets[3])
+ self.assertEqual(set([1, 3]), sets[5])
class TestIO(PySparkTestCase):