author:    mcheah <mcheah@palantir.com>  2015-03-19 08:51:49 -0400
committer: Sean Owen <sowen@cloudera.com>  2015-03-19 08:51:49 -0400
commit:    3c4e486b9c8d3f256e801db7c176ab650c976135 (patch)
tree:      2069a9298b1751278f08fb9ca25297cf2e1b9cc9 /core
parent:    797f8a000773d848fa52c7fe2eb1b5e5e7f6c55a (diff)
[SPARK-5843] [API] Allowing map-side combine to be specified in Java.
Specifically, when calling JavaPairRDD.combineByKey(), there is a new six-parameter method that exposes the map-side-combine boolean as the fifth parameter and the serializer as the sixth parameter.

Author: mcheah <mcheah@palantir.com>

Closes #4634 from mccheah/pair-rdd-map-side-combine and squashes the following commits:

5c58319 [mcheah] Fixing compiler errors.
3ce7deb [mcheah] Addressing style and documentation comments.
7455c7a [mcheah] Allowing Java combineByKey to specify Serializer as well.
6ddd729 [mcheah] [SPARK-5843] Allowing map-side combine to be specified in Java.
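For readers who want to try the new overload, here is a minimal, self-contained sketch (not part of the commit) of how it can be called from user code. It assumes a Spark build that includes this patch, local-mode execution, and Java 8 lambda syntax for the org.apache.spark.api.java.function interfaces; the class name and sample data are illustrative only.

import java.util.Arrays;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.serializer.KryoSerializer;

public class CombineByKeySketch {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("combineByKey-sketch").setMaster("local[2]");
    JavaSparkContext sc = new JavaSparkContext(conf);

    JavaPairRDD<String, Integer> pairs = sc.parallelizePairs(Arrays.asList(
        new Tuple2<>("a", 1), new Tuple2<>("b", 2), new Tuple2<>("a", 3)));

    // The new six-parameter overload: the fifth argument controls map-side
    // combine and the sixth supplies the serializer used for the shuffle.
    JavaPairRDD<String, Integer> sums = pairs.combineByKey(
        v -> v,                     // createCombiner: lift a value into a combiner
        (c, v) -> c + v,            // mergeValue: fold a value into a combiner
        (c1, c2) -> c1 + c2,        // mergeCombiners: merge two combiners
        new HashPartitioner(2),     // partitioner for the output RDD
        false,                      // mapSideCombine disabled
        new KryoSerializer(conf));  // serializer for the shuffle

    Map<String, Integer> result = sums.collectAsMap();  // {a=4, b=2}
    System.out.println(result);
    sc.stop();
  }
}

With map-side combine disabled, each (key, value) pair is shuffled unmerged and combining happens only on the reduce side; passing a non-null serializer makes the shuffle use it instead of the default configured serializer.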
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala  46
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java  53
2 files changed, 87 insertions(+), 12 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 4eadc9a856..a023712be1 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -39,6 +39,7 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
import org.apache.spark.rdd.RDD.rddToPairRDDFunctions
+import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -227,24 +228,51 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
* - `mergeCombiners`, to combine two C's into a single one.
*
- * In addition, users can control the partitioning of the output RDD, and whether to perform
- * map-side aggregation (if a mapper can produce multiple items with the same key).
+ * In addition, users can control the partitioning of the output RDD, the serializer that is used
+ * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple
+ * items with the same key).
*/
def combineByKey[C](createCombiner: JFunction[V, C],
- mergeValue: JFunction2[C, V, C],
- mergeCombiners: JFunction2[C, C, C],
- partitioner: Partitioner): JavaPairRDD[K, C] = {
- implicit val ctag: ClassTag[C] = fakeClassTag
+ mergeValue: JFunction2[C, V, C],
+ mergeCombiners: JFunction2[C, C, C],
+ partitioner: Partitioner,
+ mapSideCombine: Boolean,
+ serializer: Serializer): JavaPairRDD[K, C] = {
+ implicit val ctag: ClassTag[C] = fakeClassTag
fromRDD(rdd.combineByKey(
createCombiner,
mergeValue,
mergeCombiners,
- partitioner
+ partitioner,
+ mapSideCombine,
+ serializer
))
}
/**
- * Simplified version of combineByKey that hash-partitions the output RDD.
+ * Generic function to combine the elements for each key using a custom set of aggregation
+ * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a
+ * "combined type" C * Note that V and C can be different -- for example, one might group an
+ * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three
+ * functions:
+ *
+ * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
+ * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
+ * - `mergeCombiners`, to combine two C's into a single one.
+ *
+ * In addition, users can control the partitioning of the output RDD. This method automatically
+ * uses map-side aggregation in shuffling the RDD.
+ */
+ def combineByKey[C](createCombiner: JFunction[V, C],
+ mergeValue: JFunction2[C, V, C],
+ mergeCombiners: JFunction2[C, C, C],
+ partitioner: Partitioner): JavaPairRDD[K, C] = {
+ combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner, true, null)
+ }
+
+ /**
+ * Simplified version of combineByKey that hash-partitions the output RDD and uses map-side
+ * aggregation.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
@@ -488,7 +516,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
/**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the existing
- * partitioner/parallelism level.
+ * partitioner/parallelism level and using map-side aggregation.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 8ec54360ca..d4b5bb5191 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -24,11 +24,12 @@ import java.net.URI;
import java.util.*;
import java.util.concurrent.*;
-import org.apache.spark.input.PortableDataStream;
+import scala.collection.JavaConversions;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
@@ -51,8 +52,11 @@ import org.junit.Test;
import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.*;
import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.input.PortableDataStream;
import org.apache.spark.partial.BoundedDouble;
import org.apache.spark.partial.PartialResult;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.util.StatCounter;
@@ -726,8 +730,8 @@ public class JavaAPISuite implements Serializable {
Tuple2<double[], long[]> results = rdd.histogram(2);
double[] expected_buckets = {1.0, 2.5, 4.0};
long[] expected_counts = {2, 2};
- Assert.assertArrayEquals(expected_buckets, results._1, 0.1);
- Assert.assertArrayEquals(expected_counts, results._2);
+ Assert.assertArrayEquals(expected_buckets, results._1(), 0.1);
+ Assert.assertArrayEquals(expected_counts, results._2());
// Test with provided buckets
long[] histogram = rdd.histogram(expected_buckets);
Assert.assertArrayEquals(expected_counts, histogram);
@@ -1424,6 +1428,49 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
}
+ @Test
+ public void combineByKey() {
+ JavaRDD<Integer> originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6));
+ Function<Integer, Integer> keyFunction = new Function<Integer, Integer>() {
+ @Override
+ public Integer call(Integer v1) throws Exception {
+ return v1 % 3;
+ }
+ };
+ Function<Integer, Integer> createCombinerFunction = new Function<Integer, Integer>() {
+ @Override
+ public Integer call(Integer v1) throws Exception {
+ return v1;
+ }
+ };
+
+ Function2<Integer, Integer, Integer> mergeValueFunction = new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer v1, Integer v2) throws Exception {
+ return v1 + v2;
+ }
+ };
+
+ JavaPairRDD<Integer, Integer> combinedRDD = originalRDD.keyBy(keyFunction)
+ .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction);
+ Map<Integer, Integer> results = combinedRDD.collectAsMap();
+ ImmutableMap<Integer, Integer> expected = ImmutableMap.of(0, 9, 1, 5, 2, 7);
+ Assert.assertEquals(expected, results);
+
+ Partitioner defaultPartitioner = Partitioner.defaultPartitioner(
+ combinedRDD.rdd(), JavaConversions.asScalaBuffer(Lists.<RDD<?>>newArrayList()));
+ combinedRDD = originalRDD.keyBy(keyFunction)
+ .combineByKey(
+ createCombinerFunction,
+ mergeValueFunction,
+ mergeValueFunction,
+ defaultPartitioner,
+ false,
+ new KryoSerializer(new SparkConf()));
+ results = combinedRDD.collectAsMap();
+ Assert.assertEquals(expected, results);
+ }
+
@SuppressWarnings("unchecked")
@Test
public void mapOnPairRDD() {