aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/spark/PairRDDFunctions.scala27
-rw-r--r--core/src/main/scala/spark/api/java/JavaPairRDD.scala24
-rw-r--r--core/src/test/scala/spark/JavaAPISuite.java22
3 files changed, 73 insertions, 0 deletions
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 3d1b1ca268..07efba9e8d 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -89,6 +89,33 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
/**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may be
+ * added to the result any number of times without changing it (e.g., Nil for list concatenation, 0 for
+ * addition, or 1 for multiplication). The given `partitioner` is passed through to combineByKey.
+ */
+ def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
+ combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner)
+ }
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may be
+ * added to the result any number of times without changing it (e.g., Nil for list concatenation, 0 for
+ * addition, or 1 for multiplication). Builds a HashPartitioner from `numPartitions` and delegates.
+ */
+ def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = {
+ foldByKey(zeroValue, new HashPartitioner(numPartitions))(func)
+ }
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may be
+ * added to the result any number of times without changing it (e.g., Nil for list concatenation, 0 for
+ * addition, or 1 for multiplication). Uses defaultPartitioner(self) as the partitioner and delegates.
+ */
+ def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = {
+ foldByKey(zeroValue, defaultPartitioner(self))(func)
+ }
+
+ /**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce.
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index c1bd13c49a..49aaabf835 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -161,6 +161,30 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
/**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may be
+ * added to the result any number of times without changing it (e.g., Nil for list concatenation, 0 for
+ * addition, or 1 for multiplication). Delegates to rdd.foldByKey with the given `partitioner`.
+ */
+ def foldByKey(zeroValue: V, partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.foldByKey(zeroValue, partitioner)(func))
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may be
+ * added to the result any number of times without changing it (e.g., Nil for list concatenation, 0 for
+ * addition, or 1 for multiplication). Delegates to rdd.foldByKey with the given `numPartitions`.
+ */
+ def foldByKey(zeroValue: V, numPartitions: Int, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.foldByKey(zeroValue, numPartitions)(func))
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may be
+ * added to the result any number of times without changing it (e.g., Nil for list concatenation, 0 for
+ * addition, or 1 for multiplication). Delegates to rdd.foldByKey without specifying a partitioner.
+ */
+ def foldByKey(zeroValue: V, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.foldByKey(zeroValue)(func))
+
+ /**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 26e3ab72c0..d3dcd3bbeb 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -197,6 +197,28 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void foldByKey() { // folds the pairs below with zero value 0 and integer addition
+ List<Tuple2<Integer, Integer>> pairs = Arrays.asList(
+ new Tuple2<Integer, Integer>(2, 1),
+ new Tuple2<Integer, Integer>(2, 1),
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(3, 2),
+ new Tuple2<Integer, Integer>(3, 1)
+ );
+ JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> sums = rdd.foldByKey(0, // 0 is the neutral element for addition
+ new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer a, Integer b) {
+ return a + b;
+ }
+ });
+ Assert.assertEquals(1, sums.lookup(1).get(0).intValue()); // key 1: single value 1
+ Assert.assertEquals(2, sums.lookup(2).get(0).intValue()); // key 2: 1 + 1
+ Assert.assertEquals(3, sums.lookup(3).get(0).intValue()); // key 3: 2 + 1
+ }
+
+ @Test
public void reduceByKey() {
List<Tuple2<Integer, Integer>> pairs = Arrays.asList(
new Tuple2<Integer, Integer>(2, 1),