diff options
8 files changed, 152 insertions, 5 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 8e8f7f6c4f..79e4ebf2db 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -32,7 +32,8 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter import org.apache.spark.util.Utils -class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, JavaDoubleRDD] { +class JavaDoubleRDD(val srdd: RDD[scala.Double]) + extends AbstractJavaRDDLike[JDouble, JavaDoubleRDD] { override val classTag: ClassTag[JDouble] = implicitly[ClassTag[JDouble]] diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 7af3538262..4eadc9a856 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -44,7 +44,7 @@ import org.apache.spark.util.Utils class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) (implicit val kClassTag: ClassTag[K], implicit val vClassTag: ClassTag[V]) - extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] { + extends AbstractJavaRDDLike[(K, V), JavaPairRDD[K, V]] { override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 86fb374bef..645dc3bfb6 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) - extends JavaRDDLike[T, JavaRDD[T]] { + extends AbstractJavaRDDLike[T, JavaRDD[T]] { override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 0f91c942ec..8da42934a7 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -39,6 +39,14 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils /** + * As a workaround for https://issues.scala-lang.org/browse/SI-8905, implementations + * of JavaRDDLike should extend this dummy abstract class instead of directly inheriting + * from the trait. See SPARK-3266 for additional details. + */ +private[spark] abstract class AbstractJavaRDDLike[T, This <: JavaRDDLike[T, This]] + extends JavaRDDLike[T, This] + +/** * Defines operations common to several Java RDD implementations. * Note that this trait is not intended to be implemented by user code. */ diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index b16a1e9460..dc9f77360d 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -268,6 +268,22 @@ public class JavaAPISuite implements Serializable { } @Test + public void foreachPartition() { + final Accumulator<Integer> accum = sc.accumulator(0); + JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello", "World")); + rdd.foreachPartition(new VoidFunction<Iterator<String>>() { + @Override + public void call(Iterator<String> iter) throws IOException { + while (iter.hasNext()) { + iter.next(); + accum.add(1); + } + } + }); + Assert.assertEquals(2, accum.value().intValue()); + } + + @Test public void toLocalIterator() { List<Integer> correct = Arrays.asList(1, 2, 3, 4); JavaRDD<Integer> rdd = sc.parallelize(correct); @@ -658,6 +674,13 @@ public class JavaAPISuite implements Serializable { } @Test + public void toArray() { + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3)); + List<Integer> list = rdd.toArray(); + Assert.assertEquals(Arrays.asList(1, 2, 3), list); + } + + @Test public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); JavaRDD<String> stringRDD = sc.parallelize(Arrays.asList("Hello", "World")); @@ -710,6 +733,80 @@ public class JavaAPISuite implements Serializable { Assert.assertArrayEquals(expected_counts, histogram); } + private static class DoubleComparator implements Comparator<Double>, Serializable { + public int compare(Double o1, Double o2) { + return o1.compareTo(o2); + } + } + + @Test + public void max() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.max(new DoubleComparator()); + Assert.assertEquals(4.0, max, 0.001); + } + + @Test + public void min() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.min(new DoubleComparator()); + Assert.assertEquals(1.0, max, 0.001); + } + + @Test + public void takeOrdered() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2, new DoubleComparator())); + Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2)); + } + + @Test + public void top() { + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + List<Integer> top2 = rdd.top(2); + Assert.assertEquals(Arrays.asList(4, 3), top2); + } + + private static class AddInts implements Function2<Integer, Integer, Integer> { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + } + + @Test + public void reduce() { + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + int sum = rdd.reduce(new AddInts()); + Assert.assertEquals(10, sum); + } + + @Test + public void reduceOnJavaDoubleRDD() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double sum = rdd.reduce(new Function2<Double, Double, Double>() { + @Override + public Double call(Double v1, Double v2) throws Exception { + return v1 + v2; + } + }); + Assert.assertEquals(10.0, sum, 0.001); + } + + @Test + public void fold() { + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + int sum = rdd.fold(0, new AddInts()); + Assert.assertEquals(10, sum); + } + + @Test + public void aggregate() { + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + int sum = rdd.aggregate(0, new AddInts(), new AddInts()); + Assert.assertEquals(10, sum); + } + @Test public void map() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); @@ -826,6 +923,25 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); } + + @Test + public void mapPartitionsWithIndex() { + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); + JavaRDD<Integer> partitionSums = rdd.mapPartitionsWithIndex( + new Function2<Integer, Iterator<Integer>, Iterator<Integer>>() { + @Override + public Iterator<Integer> call(Integer index, Iterator<Integer> iter) throws Exception { + int sum = 0; + while (iter.hasNext()) { + sum += iter.next(); + } + return Collections.singletonList(sum).iterator(); + } + }, false); + Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); + } + + @Test public void repartition() { // Shrinking number of partitions @@ -1513,6 +1629,19 @@ public class JavaAPISuite implements Serializable { } @Test + public void takeAsync() throws Exception { + List<Integer> data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD<Integer> rdd = sc.parallelize(data, 1); + JavaFutureAction<List<Integer>> future = rdd.takeAsync(1); + List<Integer> result = future.get(); + Assert.assertEquals(1, result.size()); + Assert.assertEquals((Integer) 1, result.get(0)); + Assert.assertFalse(future.isCancelled()); + Assert.assertTrue(future.isDone()); + Assert.assertEquals(1, future.jobIds().size()); + } + + @Test public void foreachAsync() throws Exception { List<Integer> data = Arrays.asList(1, 2, 3, 4, 5); JavaRDD<Integer> rdd = sc.parallelize(data, 1); diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala index 505e4431e4..01cdcb0574 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala @@ -36,7 +36,7 @@ import org.apache.spark.streaming.dstream.DStream * [[org.apache.spark.streaming.api.java.JavaPairDStream]]. */ class JavaDStream[T](val dstream: DStream[T])(implicit val classTag: ClassTag[T]) - extends JavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] { + extends AbstractJavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] { override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index c382a12f4d..2eabdd9387 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -34,6 +34,15 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.api.java.JavaDStream._ import org.apache.spark.streaming.dstream.DStream +/** + * As a workaround for https://issues.scala-lang.org/browse/SI-8905, implementations + * of JavaDStreamLike should extend this dummy abstract class instead of directly inheriting + * from the trait. See SPARK-3266 for additional details. + */ +private[streaming] +abstract class AbstractJavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], + R <: JavaRDDLike[T, R]] extends JavaDStreamLike[T, This, R] + trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]] extends Serializable { implicit val classTag: ClassTag[T] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index bd01789b61..f94f2d0e8b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -45,7 +45,7 @@ import org.apache.spark.streaming.dstream.DStream class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( implicit val kManifest: ClassTag[K], implicit val vManifest: ClassTag[V]) - extends JavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] { + extends AbstractJavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] { override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) |