diff options
author | mcheah <mcheah@palantir.com> | 2015-03-19 08:51:49 -0400 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2015-03-19 08:51:49 -0400 |
commit | 3c4e486b9c8d3f256e801db7c176ab650c976135 (patch) | |
tree | 2069a9298b1751278f08fb9ca25297cf2e1b9cc9 /core/src/test/java | |
parent | 797f8a000773d848fa52c7fe2eb1b5e5e7f6c55a (diff) | |
download | spark-3c4e486b9c8d3f256e801db7c176ab650c976135.tar.gz spark-3c4e486b9c8d3f256e801db7c176ab650c976135.tar.bz2 spark-3c4e486b9c8d3f256e801db7c176ab650c976135.zip |
[SPARK-5843] [API] Allowing map-side combine to be specified in Java.
Specifically, when calling JavaPairRDD.combineByKey(), there is a new
six-parameter method that exposes the map-side-combine boolean as the
fifth parameter and the serializer as the sixth parameter.
Author: mcheah <mcheah@palantir.com>
Closes #4634 from mccheah/pair-rdd-map-side-combine and squashes the following commits:
5c58319 [mcheah] Fixing compiler errors.
3ce7deb [mcheah] Addressing style and documentation comments.
7455c7a [mcheah] Allowing Java combineByKey to specify Serializer as well.
6ddd729 [mcheah] [SPARK-5843] Allowing map-side combine to be specified in Java.
Diffstat (limited to 'core/src/test/java')
-rw-r--r-- | core/src/test/java/org/apache/spark/JavaAPISuite.java | 53 |
1 file changed, 50 insertions, 3 deletions
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 8ec54360ca..d4b5bb5191 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -24,11 +24,12 @@ import java.net.URI; import java.util.*; import java.util.concurrent.*; -import org.apache.spark.input.PortableDataStream; +import scala.collection.JavaConversions; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; import com.google.common.collect.Lists; @@ -51,8 +52,11 @@ import org.junit.Test; import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.*; import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.input.PortableDataStream; import org.apache.spark.partial.BoundedDouble; import org.apache.spark.partial.PartialResult; +import org.apache.spark.rdd.RDD; +import org.apache.spark.serializer.KryoSerializer; import org.apache.spark.storage.StorageLevel; import org.apache.spark.util.StatCounter; @@ -726,8 +730,8 @@ public class JavaAPISuite implements Serializable { Tuple2<double[], long[]> results = rdd.histogram(2); double[] expected_buckets = {1.0, 2.5, 4.0}; long[] expected_counts = {2, 2}; - Assert.assertArrayEquals(expected_buckets, results._1, 0.1); - Assert.assertArrayEquals(expected_counts, results._2); + Assert.assertArrayEquals(expected_buckets, results._1(), 0.1); + Assert.assertArrayEquals(expected_counts, results._2()); // Test with provided buckets long[] histogram = rdd.histogram(expected_buckets); Assert.assertArrayEquals(expected_counts, histogram); @@ -1424,6 +1428,49 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); } + @Test + public void combineByKey() { + JavaRDD<Integer> originalRDD = 
sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6)); + Function<Integer, Integer> keyFunction = new Function<Integer, Integer>() { + @Override + public Integer call(Integer v1) throws Exception { + return v1 % 3; + } + }; + Function<Integer, Integer> createCombinerFunction = new Function<Integer, Integer>() { + @Override + public Integer call(Integer v1) throws Exception { + return v1; + } + }; + + Function2<Integer, Integer, Integer> mergeValueFunction = new Function2<Integer, Integer, Integer>() { + @Override + public Integer call(Integer v1, Integer v2) throws Exception { + return v1 + v2; + } + }; + + JavaPairRDD<Integer, Integer> combinedRDD = originalRDD.keyBy(keyFunction) + .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction); + Map<Integer, Integer> results = combinedRDD.collectAsMap(); + ImmutableMap<Integer, Integer> expected = ImmutableMap.of(0, 9, 1, 5, 2, 7); + Assert.assertEquals(expected, results); + + Partitioner defaultPartitioner = Partitioner.defaultPartitioner( + combinedRDD.rdd(), JavaConversions.asScalaBuffer(Lists.<RDD<?>>newArrayList())); + combinedRDD = originalRDD.keyBy(keyFunction) + .combineByKey( + createCombinerFunction, + mergeValueFunction, + mergeValueFunction, + defaultPartitioner, + false, + new KryoSerializer(new SparkConf())); + results = combinedRDD.collectAsMap(); + Assert.assertEquals(expected, results); + } + @SuppressWarnings("unchecked") @Test public void mapOnPairRDD() { |