aboutsummaryrefslogtreecommitdiff
path: root/core/src
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@eecs.berkeley.edu>2013-01-26 15:57:01 -0800
committerJosh Rosen <joshrosen@eecs.berkeley.edu>2013-01-26 16:13:18 -0800
commitd49cf0e587b7cbbd31917d9bb69f98466feb0f9f (patch)
tree47fbfd578856e6a3337d5259c282f95f09ff485f /core/src
parent2435b7b5b77fe3e30fc8175983460ceb7a70632e (diff)
downloadspark-d49cf0e587b7cbbd31917d9bb69f98466feb0f9f.tar.gz
spark-d49cf0e587b7cbbd31917d9bb69f98466feb0f9f.tar.bz2
spark-d49cf0e587b7cbbd31917d9bb69f98466feb0f9f.zip
Fix JavaRDDLike.flatMap(PairFlatMapFunction) (SPARK-668).
This workaround is easier than rewriting JavaRDDLike in Java.
Diffstat (limited to 'core/src')
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDDLike.scala7
-rw-r--r--core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java20
-rw-r--r--core/src/test/scala/spark/JavaAPISuite.java28
3 files changed, 51 insertions, 4 deletions
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index b3698ffa44..4c95c989b5 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -12,7 +12,7 @@ import spark.storage.StorageLevel
import com.google.common.base.Optional
-trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
+trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] {
def wrapRDD(rdd: RDD[T]): This
implicit val classManifest: ClassManifest[T]
@@ -82,10 +82,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Return a new RDD by first applying a function to all elements of this
- * RDD, and then flattening the results.
+ * Part of the workaround for SPARK-668; called in PairFlatMapWorkaround.java.
*/
- def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
+ private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
diff --git a/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java
new file mode 100644
index 0000000000..68b6fd6622
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java
@@ -0,0 +1,20 @@
+package spark.api.java;
+
+import spark.api.java.JavaPairRDD;
+import spark.api.java.JavaRDDLike;
+import spark.api.java.function.PairFlatMapFunction;
+
+import java.io.Serializable;
+
+/**
+ * Workaround for SPARK-668.
+ */
+class PairFlatMapWorkaround<T> implements Serializable {
+ /**
+ * Return a new RDD by first applying a function to all elements of this
+ * RDD, and then flattening the results.
+ */
+ public <K, V> JavaPairRDD<K, V> flatMap(PairFlatMapFunction<T, K, V> f) {
+ return ((JavaRDDLike <T, ?>) this).doFlatMap(f);
+ }
+}
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 01351de4ae..f50ba093e9 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -356,6 +356,34 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void mapsFromPairsToPairs() {
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs);
+
+ // Regression test for SPARK-668:
+ JavaPairRDD<String, Integer> swapped = pairRDD.flatMap(
+ new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() {
+ @Override
+ public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) throws Exception {
+ return Collections.singletonList(item.swap());
+ }
+ });
+ swapped.collect();
+
+ // There was never a bug here, but it's worth testing:
+ pairRDD.map(new PairFunction<Tuple2<Integer, String>, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(Tuple2<Integer, String> item) throws Exception {
+ return item.swap();
+ }
+ }).collect();
+ }
+
+ @Test
public void mapPartitions() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
JavaRDD<Integer> partitionSums = rdd.mapPartitions(