Diffstat (limited to 'core')
-rw-r--r--   core/src/main/scala/spark/api/java/JavaRDDLike.scala                | 15 +++++++++++++++
-rw-r--r--   core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala  | 11 +++++++++++
-rw-r--r--   core/src/test/scala/spark/JavaAPISuite.java                         | 26 ++++++++++++++++++++++
3 files changed, 52 insertions(+), 0 deletions(-)
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index d884529d7a..9b74d1226f 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -182,6 +182,21 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest)
   }
 
+  /**
+   * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
+   * applying a function to the zipped partitions. Assumes that all the RDDs have the
+   * *same number of partitions*, but does *not* require them to have the same number
+   * of elements in each partition.
+   */
+  def zipPartitions[U, V](
+      f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V],
+      other: JavaRDDLike[U, _]): JavaRDD[V] = {
+    def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
+      f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
+    JavaRDD.fromRDD(
+      rdd.zipPartitions(fn, other.rdd)(other.classManifest, f.elementType()))(f.elementType())
+  }
+
   // Actions (launch a job to return a value to the user program)
 
   /**
diff --git a/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala
new file mode 100644
index 0000000000..6044043add
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala
@@ -0,0 +1,11 @@
+package spark.api.java.function
+
+/**
+ * A function that takes two inputs and returns zero or more output records.
+ */
+abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
+  @throws(classOf[Exception])
+  def call(a: A, b: B): java.lang.Iterable[C]
+
+  def elementType(): ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
+}
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index d3dcd3bbeb..93bb69b41c 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -633,6 +633,32 @@ public class JavaAPISuite implements Serializable {
   }
 
   @Test
+  public void zipPartitions() {
+    JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2);
+    JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2);
+    FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer> sizesFn =
+      new FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer>() {
+        @Override
+        public Iterable<Integer> call(Iterator<Integer> i, Iterator<String> s) {
+          int sizeI = 0;
+          int sizeS = 0;
+          while (i.hasNext()) {
+            sizeI += 1;
+            i.next();
+          }
+          while (s.hasNext()) {
+            sizeS += 1;
+            s.next();
+          }
+          return Arrays.asList(sizeI, sizeS);
+        }
+      };
+
+    JavaRDD<Integer> sizes = rdd1.zipPartitions(sizesFn, rdd2);
+    Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString());
+  }
+
+  @Test
   public void accumulators() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
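
For context, here is a minimal sketch of a driver program exercising the `zipPartitions` method this commit adds. It assumes the pre-Apache `spark.api.java` package layout shown in the diff; the class name `ZipPartitionsExample` and the local-mode `JavaSparkContext` are illustrative choices, not part of the commit. It pairs elements positionally within each partition, stopping when the shorter iterator is exhausted, which the plain `zip` cannot do because `zip` requires equal element counts per partition.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.FlatMapFunction2;

public class ZipPartitionsExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "ZipPartitionsExample");

    // Same number of partitions (2) on both sides, but different element
    // counts per partition; zipPartitions permits this, unlike zip().
    JavaRDD<Integer> nums = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2);
    JavaRDD<String> letters = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2);

    // Pair elements positionally within each partition, dropping the
    // leftovers once the shorter iterator runs out.
    JavaRDD<String> zipped = nums.zipPartitions(
      new FlatMapFunction2<Iterator<Integer>, Iterator<String>, String>() {
        @Override
        public Iterable<String> call(Iterator<Integer> i, Iterator<String> s) {
          List<String> out = new ArrayList<String>();
          while (i.hasNext() && s.hasNext()) {
            out.add(i.next() + s.next());
          }
          return out;
        }
      }, letters);

    // Partitions are [1, 2, 3] / [4, 5, 6] and [a, b] / [c, d],
    // so this prints [1a, 2b, 4c, 5d].
    System.out.println(zipped.collect());
  }
}

Note that because `FlatMapFunction2.elementType()` defaults to `ClassManifest.Any`, the anonymous subclass does not need to override it.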