diff options
author | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2013-01-26 15:57:01 -0800 |
---|---|---|
committer | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2013-01-26 16:13:18 -0800 |
commit | d49cf0e587b7cbbd31917d9bb69f98466feb0f9f (patch) | |
tree | 47fbfd578856e6a3337d5259c282f95f09ff485f | |
parent | 2435b7b5b77fe3e30fc8175983460ceb7a70632e (diff) | |
download | spark-d49cf0e587b7cbbd31917d9bb69f98466feb0f9f.tar.gz spark-d49cf0e587b7cbbd31917d9bb69f98466feb0f9f.tar.bz2 spark-d49cf0e587b7cbbd31917d9bb69f98466feb0f9f.zip |
Fix JavaRDDLike.flatMap(PairFlatMapFunction) (SPARK-668).
This workaround is easier than rewriting JavaRDDLike in Java.
-rw-r--r-- | core/src/main/scala/spark/api/java/JavaRDDLike.scala | 7 | ||||
-rw-r--r-- | core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java | 20 | ||||
-rw-r--r-- | core/src/test/scala/spark/JavaAPISuite.java | 28 |
3 files changed, 51 insertions(+), 4 deletions(-)
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index b3698ffa44..4c95c989b5 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -12,7 +12,7 @@ import spark.storage.StorageLevel import com.google.common.base.Optional -trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { +trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] { def wrapRDD(rdd: RDD[T]): This implicit val classManifest: ClassManifest[T] @@ -82,10 +82,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. + * Part of the workaround for SPARK-668; called in PairFlatMapWorkaround.java. */ - def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = { + private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = { import scala.collection.JavaConverters._ def fn = (x: T) => f.apply(x).asScala def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] diff --git a/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java new file mode 100644 index 0000000000..68b6fd6622 --- /dev/null +++ b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java @@ -0,0 +1,20 @@ +package spark.api.java; + +import spark.api.java.JavaPairRDD; +import spark.api.java.JavaRDDLike; +import spark.api.java.function.PairFlatMapFunction; + +import java.io.Serializable; + +/** + * Workaround for SPARK-668. + */ +class PairFlatMapWorkaround<T> implements Serializable { + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. 
+ */ + public <K, V> JavaPairRDD<K, V> flatMap(PairFlatMapFunction<T, K, V> f) { + return ((JavaRDDLike <T, ?>) this).doFlatMap(f); + } +} diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 01351de4ae..f50ba093e9 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -356,6 +356,34 @@ public class JavaAPISuite implements Serializable { } @Test + public void mapsFromPairsToPairs() { + List<Tuple2<Integer, String>> pairs = Arrays.asList( + new Tuple2<Integer, String>(1, "a"), + new Tuple2<Integer, String>(2, "aa"), + new Tuple2<Integer, String>(3, "aaa") + ); + JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD<String, Integer> swapped = pairRDD.flatMap( + new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() { + @Override + public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) throws Exception { + return Collections.singletonList(item.swap()); + } + }); + swapped.collect(); + + // There was never a bug here, but it's worth testing: + pairRDD.map(new PairFunction<Tuple2<Integer, String>, String, Integer>() { + @Override + public Tuple2<String, Integer> call(Tuple2<Integer, String> item) throws Exception { + return item.swap(); + } + }).collect(); + } + + @Test public void mapPartitions() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); JavaRDD<Integer> partitionSums = rdd.mapPartitions( |