author    Mark Hamstra <markhamstra@gmail.com>  2013-03-14 13:58:37 -0700
committer Mark Hamstra <markhamstra@gmail.com>  2013-03-14 13:58:37 -0700
commit    16a4ca45373cd4b75032f88668610d9b693fb4b3
tree      af70d212487d5805d52dec482a926fbe6cb8fb55
parent    b1422cbdd59569261c9df84034a5a03833ec3996
Restrict the V type of foldByKey in order to retain its ClassManifest; add foldByKey to the Java API, with a test.
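The message refers to a real constraint of pre-2.10-era Spark: RDD operations needed an implicit ClassManifest for their element types, and the old widened parameter V1 >: V carried none. A standalone sketch of that mechanism, outside Spark and with illustrative names that are not part of this patch:

    // Standalone illustration (not from this patch) of why an unconstrained
    // supertype parameter loses the ClassManifest that Spark's RDDs required.
    object ManifestSketch {
      // The context bound [V: ClassManifest] supplies an implicit
      // ClassManifest[V], analogous to PairRDDFunctions' own bounds.
      def requiresManifest[V: ClassManifest](v: V): String =
        implicitly[ClassManifest[V]].toString

      // For a widened parameter V1 >: V there is no implicit
      // ClassManifest[V1] in scope, so the analogous call would not compile:
      // def widened[V: ClassManifest, V1 >: V](v: V1) = requiresManifest(v)

      def main(args: Array[String]): Unit =
        println(requiresManifest(42))  // prints "Int"
    }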
 core/src/main/scala/spark/PairRDDFunctions.scala     |  4 ++--
 core/src/main/scala/spark/api/java/JavaPairRDD.scala |  6 ++++++
 core/src/test/scala/spark/JavaAPISuite.java          | 24 +++++++++++++++++++++++-
 3 files changed, 31 insertions(+), 3 deletions(-)
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index a6e00c3a84..0fde902261 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -91,8 +91,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
   /**
    * Merge the values for each key using an associative function and a neutral "zero value".
    */
-  def foldByKey[V1 >: V](zeroValue: V1)(op: (V1, V1) => V1): RDD[(K, V1)] = {
-    groupByKey.mapValues(seq => seq.fold(zeroValue)(op))
+  def foldByKey(zeroValue: V)(op: (V, V) => V): RDD[(K, V)] = {
+    groupByKey.mapValues(seq => seq.fold[V](zeroValue)(op))
   }
 
   /**
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index c1bd13c49a..1e1c910202 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -161,6 +161,12 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
     rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
 
   /**
+   * Merge the values for each key using an associative function and a neutral "zero value".
+   */
+  def foldByKey(zeroValue: V, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
+    fromRDD(rdd.foldByKey(zeroValue)(func))
+
+  /**
    * Merge the values for each key using an associative reduce function. This will also perform
    * the merging locally on each mapper before sending results to a reducer, similarly to a
    * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 26e3ab72c0..b83076b929 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -196,7 +196,29 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(33, sum);
   }
 
-  @Test
+  @Test
+  public void foldByKey() {
+    List<Tuple2<Integer, Integer>> pairs = Arrays.asList(
+        new Tuple2<Integer, Integer>(2, 1),
+        new Tuple2<Integer, Integer>(2, 1),
+        new Tuple2<Integer, Integer>(1, 1),
+        new Tuple2<Integer, Integer>(3, 2),
+        new Tuple2<Integer, Integer>(3, 1)
+    );
+    JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
+    JavaPairRDD<Integer, Integer> sums = rdd.foldByKey(0,
+        new Function2<Integer, Integer, Integer>() {
+          @Override
+          public Integer call(Integer a, Integer b) {
+            return a + b;
+          }
+        });
+    Assert.assertEquals(1, sums.lookup(1).get(0).intValue());
+    Assert.assertEquals(2, sums.lookup(2).get(0).intValue());
+    Assert.assertEquals(3, sums.lookup(3).get(0).intValue());
+  }
+
+  @Test
   public void reduceByKey() {
     List<Tuple2<Integer, Integer>> pairs = Arrays.asList(
         new Tuple2<Integer, Integer>(2, 1),
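For reference, a usage sketch of the tightened Scala signature; this is not part of the patch, and the local master URL and app name are illustrative:

    import spark.SparkContext
    import spark.SparkContext._  // implicit conversion to PairRDDFunctions

    object FoldByKeyExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local", "foldByKeyExample")
        // Same key/value pairs as the Java test above.
        val pairs = sc.parallelize(Seq((2, 1), (2, 1), (1, 1), (3, 2), (3, 1)))
        // zeroValue and op now share the RDD's value type V, so the result
        // RDD[(K, V)] keeps its ClassManifest.
        val sums = pairs.foldByKey(0)(_ + _)
        sums.collect().foreach(println)  // (1,1), (2,2), (3,3) in some order
        sc.stop()
      }
    }

Giving zeroValue the RDD's own value type V trades the small generality of folding into a supertype for keeping the implicit ClassManifest[V] that is already in scope on PairRDDFunctions.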