about summary refs log tree commit diff
diff options
context:
space:
mode:
author Mark Hamstra <markhamstra@gmail.com> 2013-03-14 13:58:37 -0700
committer Mark Hamstra <markhamstra@gmail.com> 2013-03-14 13:58:37 -0700
commit 16a4ca45373cd4b75032f88668610d9b693fb4b3 (patch)
tree af70d212487d5805d52dec482a926fbe6cb8fb55
parent b1422cbdd59569261c9df84034a5a03833ec3996 (diff)
download spark-16a4ca45373cd4b75032f88668610d9b693fb4b3.tar.gz
spark-16a4ca45373cd4b75032f88668610d9b693fb4b3.tar.bz2
spark-16a4ca45373cd4b75032f88668610d9b693fb4b3.zip
restrict V type of foldByKey in order to retain ClassManifest; added foldByKey to Java API and test
-rw-r--r-- core/src/main/scala/spark/PairRDDFunctions.scala 4
-rw-r--r-- core/src/main/scala/spark/api/java/JavaPairRDD.scala 6
-rw-r--r-- core/src/test/scala/spark/JavaAPISuite.java 24
3 files changed, 31 insertions, 3 deletions
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index a6e00c3a84..0fde902261 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -91,8 +91,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Merge the values for each key using an associative function and a neutral "zero value".
*/
- def foldByKey[V1 >: V](zeroValue: V1)(op: (V1, V1) => V1): RDD[(K, V1)] = {
- groupByKey.mapValues(seq => seq.fold(zeroValue)(op))
+ def foldByKey(zeroValue: V)(op: (V, V) => V): RDD[(K, V)] = {
+ groupByKey.mapValues(seq => seq.fold[V](zeroValue)(op))
}
/**
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index c1bd13c49a..1e1c910202 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -161,6 +161,12 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
/**
+ * Merge the values for each key using an associative function and a neutral "zero value".
+ */
+ def foldByKey(zeroValue: V, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.foldByKey(zeroValue)(func))
+
+ /**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 26e3ab72c0..b83076b929 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -196,7 +196,29 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals(33, sum);
}
- @Test
+ @Test
+ public void foldByKey() {
+ List<Tuple2<Integer, Integer>> pairs = Arrays.asList(
+ new Tuple2<Integer, Integer>(2, 1),
+ new Tuple2<Integer, Integer>(2, 1),
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(3, 2),
+ new Tuple2<Integer, Integer>(3, 1)
+ );
+ JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> sums = rdd.foldByKey(0,
+ new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer a, Integer b) {
+ return a + b;
+ }
+ });
+ Assert.assertEquals(1, sums.lookup(1).get(0).intValue());
+ Assert.assertEquals(2, sums.lookup(2).get(0).intValue());
+ Assert.assertEquals(3, sums.lookup(3).get(0).intValue());
+ }
+
+ @Test
public void reduceByKey() {
List<Tuple2<Integer, Integer>> pairs = Arrays.asList(
new Tuple2<Integer, Integer>(2, 1),