diff options
author | Josh Rosen <joshrosen@databricks.com> | 2015-03-17 09:18:57 -0700 |
---|---|---|
committer | Josh Rosen <joshrosen@databricks.com> | 2015-03-17 09:18:57 -0700 |
commit | 0f673c21f68ee3d5df3c01ae405709d3c1f4909b (patch) | |
tree | 3d637c661ea26d97ffa1754833f83d792ee15ef2 /core/src/test/java | |
parent | e9f22c61290348c58af54c0ae3c6226105126a8d (diff) | |
download | spark-0f673c21f68ee3d5df3c01ae405709d3c1f4909b.tar.gz spark-0f673c21f68ee3d5df3c01ae405709d3c1f4909b.tar.bz2 spark-0f673c21f68ee3d5df3c01ae405709d3c1f4909b.zip |
[SPARK-3266] Use intermediate abstract classes to fix type erasure issues in Java APIs
This PR addresses a Scala compiler bug ([SI-8905](https://issues.scala-lang.org/browse/SI-8905)) that was breaking some of the Spark Java APIs. In a nutshell, it seems that methods whose implementations are inherited from generic traits sometimes have their type parameters erased to Object. This was causing methods like `DoubleRDD.min()` to throw confusing NoSuchMethodErrors at runtime.
The fix implemented here is to introduce an intermediate layer of abstract classes and inherit from those instead of directly extends the `Java*Like` traits. This should not break binary compatibility.
I also improved the test coverage of the Java API, adding several new tests for methods that failed at runtime due to this bug.
Author: Josh Rosen <joshrosen@databricks.com>
Closes #5050 from JoshRosen/javardd-si-8905-fix and squashes the following commits:
2feb068 [Josh Rosen] Use intermediate abstract classes to work around SPARK-3266
d5f3e5d [Josh Rosen] Add failing regression tests for SPARK-3266
Diffstat (limited to 'core/src/test/java')
-rw-r--r-- | core/src/test/java/org/apache/spark/JavaAPISuite.java | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 74e88c767e..8ec54360ca 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -268,6 +268,22 @@ public class JavaAPISuite implements Serializable { } @Test + public void foreachPartition() { + final Accumulator<Integer> accum = sc.accumulator(0); + JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello", "World")); + rdd.foreachPartition(new VoidFunction<Iterator<String>>() { + @Override + public void call(Iterator<String> iter) throws IOException { + while (iter.hasNext()) { + iter.next(); + accum.add(1); + } + } + }); + Assert.assertEquals(2, accum.value().intValue()); + } + + @Test public void toLocalIterator() { List<Integer> correct = Arrays.asList(1, 2, 3, 4); JavaRDD<Integer> rdd = sc.parallelize(correct); @@ -658,6 +674,13 @@ public class JavaAPISuite implements Serializable { } @Test + public void toArray() { + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3)); + List<Integer> list = rdd.toArray(); + Assert.assertEquals(Arrays.asList(1, 2, 3), list); + } + + @Test public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); JavaRDD<String> stringRDD = sc.parallelize(Arrays.asList("Hello", "World")); @@ -714,6 +737,80 @@ public class JavaAPISuite implements Serializable { sc.parallelizeDoubles(new ArrayList<Double>(0), 1).histogram(new double[]{0.0, 1.0})); } + private static class DoubleComparator implements Comparator<Double>, Serializable { + public int compare(Double o1, Double o2) { + return o1.compareTo(o2); + } + } + + @Test + public void max() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.max(new DoubleComparator()); + Assert.assertEquals(4.0, max, 0.001); + } + + @Test + public void min() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.min(new DoubleComparator()); + Assert.assertEquals(1.0, max, 0.001); + } + + @Test + public void takeOrdered() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2, new DoubleComparator())); + Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2)); + } + + @Test + public void top() { + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + List<Integer> top2 = rdd.top(2); + Assert.assertEquals(Arrays.asList(4, 3), top2); + } + + private static class AddInts implements Function2<Integer, Integer, Integer> { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + } + + @Test + public void reduce() { + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + int sum = rdd.reduce(new AddInts()); + Assert.assertEquals(10, sum); + } + + @Test + public void reduceOnJavaDoubleRDD() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double sum = rdd.reduce(new Function2<Double, Double, Double>() { + @Override + public Double call(Double v1, Double v2) throws Exception { + return v1 + v2; + } + }); + Assert.assertEquals(10.0, sum, 0.001); + } + + @Test + public void fold() { + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + int sum = rdd.fold(0, new AddInts()); + Assert.assertEquals(10, sum); + } + + @Test + public void aggregate() { + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + int sum = rdd.aggregate(0, new AddInts(), new AddInts()); + Assert.assertEquals(10, sum); + } + @Test public void map() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); @@ -830,6 +927,25 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); } + + @Test + public void mapPartitionsWithIndex() { + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); + JavaRDD<Integer> partitionSums = rdd.mapPartitionsWithIndex( + new Function2<Integer, Iterator<Integer>, Iterator<Integer>>() { + @Override + public Iterator<Integer> call(Integer index, Iterator<Integer> iter) throws Exception { + int sum = 0; + while (iter.hasNext()) { + sum += iter.next(); + } + return Collections.singletonList(sum).iterator(); + } + }, false); + Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); + } + + @Test public void repartition() { // Shrinking number of partitions @@ -1517,6 +1633,19 @@ public class JavaAPISuite implements Serializable { } @Test + public void takeAsync() throws Exception { + List<Integer> data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD<Integer> rdd = sc.parallelize(data, 1); + JavaFutureAction<List<Integer>> future = rdd.takeAsync(1); + List<Integer> result = future.get(); + Assert.assertEquals(1, result.size()); + Assert.assertEquals((Integer) 1, result.get(0)); + Assert.assertFalse(future.isCancelled()); + Assert.assertTrue(future.isDone()); + Assert.assertEquals(1, future.jobIds().size()); + } + + @Test public void foreachAsync() throws Exception { List<Integer> data = Arrays.asList(1, 2, 3, 4, 5); JavaRDD<Integer> rdd = sc.parallelize(data, 1); |