authorSean Owen <sowen@cloudera.com>2017-02-19 09:42:50 -0800
committerSean Owen <sowen@cloudera.com>2017-02-19 09:42:50 -0800
commit1487c9af20a333ead55955acf4c0aa323bea0d07 (patch)
tree5f47daa77e0f73da1e009cc3dcf0a5c0073246aa /sql
parentde14d35f77071932963a994fac5aec0e5df838a1 (diff)
[SPARK-19534][TESTS] Convert Java tests to use lambdas, Java 8 features
## What changes were proposed in this pull request?

Convert tests to use Java 8 lambdas, and modest related fixes to surrounding code.

## How was this patch tested?

Jenkins tests

Author: Sean Owen <sowen@cloudera.com>

Closes #16964 from srowen/SPARK-19534.
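The pattern applied throughout the diffs below replaces anonymous inner classes that implement single-method (functional) interfaces with lambda expressions. A minimal, self-contained sketch of that conversion (the `StringLength` interface and `LambdaConversionSketch` class are illustrative stand-ins, not part of this patch; the real tests target Spark interfaces such as `MapFunction`, `FilterFunction`, and `UDF1`):

```java
// Minimal sketch of the anonymous-class-to-lambda conversion used across these tests.
// StringLength is a stand-in functional interface defined here only for illustration.
public class LambdaConversionSketch {

  @FunctionalInterface
  interface StringLength {
    Integer call(String value) throws Exception;
  }

  public static void main(String[] args) throws Exception {
    // Before: anonymous inner class (pre-Java 8 style removed by this patch).
    StringLength before = new StringLength() {
      @Override
      public Integer call(String value) throws Exception {
        return value.length();
      }
    };

    // After: equivalent lambda expression.
    StringLength after = value -> value.length();

    System.out.println(before.call("hello")); // prints 5
    System.out.println(after.call("hello"));  // prints 5
  }
}
```

Where the target method is overloaded for both Scala and Java function types (for example `Dataset.groupByKey`), the converted tests keep an explicit cast such as `(MapFunction<Tuple2<String, Integer>, String>)` in front of the lambda so the compiler can resolve the intended Java overload.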
Diffstat (limited to 'sql')
-rw-r--r-- sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java   2
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java   16
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java   22
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java   47
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java   49
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java   14
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java  147
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java   37
8 files changed, 108 insertions, 226 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
index 06cd9ea2d2..bf87174835 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -157,7 +157,7 @@ public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Vo
// to the accumulator. So we can check if the row groups are filtered or not in test case.
TaskContext taskContext = TaskContext$.MODULE$.get();
if (taskContext != null) {
- Option<AccumulatorV2<?, ?>> accu = (Option<AccumulatorV2<?, ?>>) taskContext.taskMetrics()
+ Option<AccumulatorV2<?, ?>> accu = taskContext.taskMetrics()
.lookForAccumulatorByName("numRowGroups");
if (accu.isDefined()) {
((LongAccumulator)accu.get()).add((long)blocks.size());
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java
index 8b8a403e2b..6ffccee52c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java
@@ -35,27 +35,35 @@ public class Java8DatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase
public void testTypedAggregationAverage() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(v -> (double)(v._2() * 2)));
- Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
+ Assert.assertEquals(
+ Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 6.0)),
+ agged.collectAsList());
}
@Test
public void testTypedAggregationCount() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(v -> v));
- Assert.assertEquals(Arrays.asList(tuple2("a", 2L), tuple2("b", 1L)), agged.collectAsList());
+ Assert.assertEquals(
+ Arrays.asList(new Tuple2<>("a", 2L), new Tuple2<>("b", 1L)),
+ agged.collectAsList());
}
@Test
public void testTypedAggregationSumDouble() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(v -> (double)v._2()));
- Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList());
+ Assert.assertEquals(
+ Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 3.0)),
+ agged.collectAsList());
}
@Test
public void testTypedAggregationSumLong() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(v -> (long)v._2()));
- Assert.assertEquals(Arrays.asList(tuple2("a", 3L), tuple2("b", 3L)), agged.collectAsList());
+ Assert.assertEquals(
+ Arrays.asList(new Tuple2<>("a", 3L), new Tuple2<>("b", 3L)),
+ agged.collectAsList());
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 573d0e3594..bf8ff61eae 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -30,7 +30,6 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
@@ -95,12 +94,7 @@ public class JavaApplySchemaSuite implements Serializable {
personList.add(person2);
JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
- new Function<Person, Row>() {
- @Override
- public Row call(Person person) throws Exception {
- return RowFactory.create(person.getName(), person.getAge());
- }
- });
+ person -> RowFactory.create(person.getName(), person.getAge()));
List<StructField> fields = new ArrayList<>(2);
fields.add(DataTypes.createStructField("name", DataTypes.StringType, false));
@@ -131,12 +125,7 @@ public class JavaApplySchemaSuite implements Serializable {
personList.add(person2);
JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
- new Function<Person, Row>() {
- @Override
- public Row call(Person person) {
- return RowFactory.create(person.getName(), person.getAge());
- }
- });
+ person -> RowFactory.create(person.getName(), person.getAge()));
List<StructField> fields = new ArrayList<>(2);
fields.add(DataTypes.createStructField("", DataTypes.StringType, false));
@@ -146,12 +135,7 @@ public class JavaApplySchemaSuite implements Serializable {
Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
df.createOrReplaceTempView("people");
List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD()
- .map(new Function<Row, String>() {
- @Override
- public String call(Row row) {
- return row.getString(0) + "_" + row.get(1);
- }
- }).collect();
+ .map(row -> row.getString(0) + "_" + row.get(1)).collect();
List<String> expected = new ArrayList<>(2);
expected.add("Michael_29");
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index c44fc3d393..c3b94a44c2 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -189,7 +189,7 @@ public class JavaDataFrameSuite {
for (int i = 0; i < d.length(); i++) {
Assert.assertEquals(bean.getD().get(i), d.apply(i));
}
- // Java.math.BigInteger is equavient to Spark Decimal(38,0)
+ // Java.math.BigInteger is equivalent to Spark Decimal(38,0)
Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4));
}
@@ -231,13 +231,10 @@ public class JavaDataFrameSuite {
Assert.assertEquals(0, schema2.fieldIndex("id"));
}
- private static final Comparator<Row> crosstabRowComparator = new Comparator<Row>() {
- @Override
- public int compare(Row row1, Row row2) {
- String item1 = row1.getString(0);
- String item2 = row2.getString(0);
- return item1.compareTo(item2);
- }
+ private static final Comparator<Row> crosstabRowComparator = (row1, row2) -> {
+ String item1 = row1.getString(0);
+ String item2 = row2.getString(0);
+ return item1.compareTo(item2);
};
@Test
@@ -249,7 +246,7 @@ public class JavaDataFrameSuite {
Assert.assertEquals("1", columnNames[1]);
Assert.assertEquals("2", columnNames[2]);
List<Row> rows = crosstab.collectAsList();
- Collections.sort(rows, crosstabRowComparator);
+ rows.sort(crosstabRowComparator);
Integer count = 1;
for (Row row : rows) {
Assert.assertEquals(row.get(0).toString(), count.toString());
@@ -284,7 +281,7 @@ public class JavaDataFrameSuite {
@Test
public void testSampleBy() {
Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
- Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
+ Dataset<Row> sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
Assert.assertEquals(0, actual.get(0).getLong(0));
Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
@@ -296,7 +293,7 @@ public class JavaDataFrameSuite {
public void pivot() {
Dataset<Row> df = spark.table("courseSales");
List<Row> actual = df.groupBy("year")
- .pivot("course", Arrays.<Object>asList("dotNET", "Java"))
+ .pivot("course", Arrays.asList("dotNET", "Java"))
.agg(sum("earnings")).orderBy("year").collectAsList();
Assert.assertEquals(2012, actual.get(0).getInt(0));
@@ -352,24 +349,24 @@ public class JavaDataFrameSuite {
Dataset<Long> df = spark.range(1000);
CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
- Assert.assertEquals(sketch1.totalCount(), 1000);
- Assert.assertEquals(sketch1.depth(), 10);
- Assert.assertEquals(sketch1.width(), 20);
+ Assert.assertEquals(1000, sketch1.totalCount());
+ Assert.assertEquals(10, sketch1.depth());
+ Assert.assertEquals(20, sketch1.width());
CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42);
- Assert.assertEquals(sketch2.totalCount(), 1000);
- Assert.assertEquals(sketch2.depth(), 10);
- Assert.assertEquals(sketch2.width(), 20);
+ Assert.assertEquals(1000, sketch2.totalCount());
+ Assert.assertEquals(10, sketch2.depth());
+ Assert.assertEquals(20, sketch2.width());
CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42);
- Assert.assertEquals(sketch3.totalCount(), 1000);
- Assert.assertEquals(sketch3.relativeError(), 0.001, 1e-4);
- Assert.assertEquals(sketch3.confidence(), 0.99, 5e-3);
+ Assert.assertEquals(1000, sketch3.totalCount());
+ Assert.assertEquals(0.001, sketch3.relativeError(), 1.0e-4);
+ Assert.assertEquals(0.99, sketch3.confidence(), 5.0e-3);
CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42);
- Assert.assertEquals(sketch4.totalCount(), 1000);
- Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4);
- Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3);
+ Assert.assertEquals(1000, sketch4.totalCount());
+ Assert.assertEquals(0.001, sketch4.relativeError(), 1.0e-4);
+ Assert.assertEquals(0.99, sketch4.confidence(), 5.0e-3);
}
@Test
@@ -389,13 +386,13 @@ public class JavaDataFrameSuite {
}
BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5);
- Assert.assertTrue(filter3.bitSize() == 64 * 5);
+ Assert.assertEquals(64 * 5, filter3.bitSize());
for (int i = 0; i < 1000; i++) {
Assert.assertTrue(filter3.mightContain(i));
}
BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5);
- Assert.assertTrue(filter4.bitSize() == 64 * 5);
+ Assert.assertEquals(64 * 5, filter4.bitSize());
for (int i = 0; i < 1000; i++) {
Assert.assertTrue(filter4.mightContain(i * 3));
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java
index fe86371516..d3769a74b9 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java
@@ -24,7 +24,6 @@ import scala.Tuple2;
import org.junit.Assert;
import org.junit.Test;
-import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
@@ -41,7 +40,9 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
Dataset<Tuple2<String, Integer>> agged = grouped.agg(new IntSumOf().toColumn());
- Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+ Assert.assertEquals(
+ Arrays.asList(new Tuple2<>("a", 3), new Tuple2<>("b", 3)),
+ agged.collectAsList());
Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(new IntSumOf().toColumn())
.as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
@@ -87,48 +88,36 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
@Test
public void testTypedAggregationAverage() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
- Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(
- new MapFunction<Tuple2<String, Integer>, Double>() {
- public Double call(Tuple2<String, Integer> value) throws Exception {
- return (double)(value._2() * 2);
- }
- }));
- Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
+ Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(value -> (double)(value._2() * 2)));
+ Assert.assertEquals(
+ Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 6.0)),
+ agged.collectAsList());
}
@Test
public void testTypedAggregationCount() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
- Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(
- new MapFunction<Tuple2<String, Integer>, Object>() {
- public Object call(Tuple2<String, Integer> value) throws Exception {
- return value;
- }
- }));
- Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList());
+ Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(value -> value));
+ Assert.assertEquals(
+ Arrays.asList(new Tuple2<>("a", 2L), new Tuple2<>("b", 1L)),
+ agged.collectAsList());
}
@Test
public void testTypedAggregationSumDouble() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
- Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(
- new MapFunction<Tuple2<String, Integer>, Double>() {
- public Double call(Tuple2<String, Integer> value) throws Exception {
- return (double)value._2();
- }
- }));
- Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList());
+ Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(value -> (double) value._2()));
+ Assert.assertEquals(
+ Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 3.0)),
+ agged.collectAsList());
}
@Test
public void testTypedAggregationSumLong() {
KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
- Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(
- new MapFunction<Tuple2<String, Integer>, Long>() {
- public Long call(Tuple2<String, Integer> value) throws Exception {
- return (long)value._2();
- }
- }));
- Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+ Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(value -> (long) value._2()));
+ Assert.assertEquals(
+ Arrays.asList(new Tuple2<>("a", 3L), new Tuple2<>("b", 3L)),
+ agged.collectAsList());
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java
index 8fc4eff55d..e62db7d2cf 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java
@@ -52,23 +52,13 @@ public class JavaDatasetAggregatorSuiteBase implements Serializable {
spark = null;
}
- protected <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
- return new Tuple2<>(t1, t2);
- }
-
protected KeyValueGroupedDataset<String, Tuple2<String, Integer>> generateGroupedDataset() {
Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
List<Tuple2<String, Integer>> data =
- Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
+ Arrays.asList(new Tuple2<>("a", 1), new Tuple2<>("a", 2), new Tuple2<>("b", 3));
Dataset<Tuple2<String, Integer>> ds = spark.createDataset(data, encoder);
- return ds.groupByKey(
- new MapFunction<Tuple2<String, Integer>, String>() {
- @Override
- public String call(Tuple2<String, Integer> value) throws Exception {
- return value._1();
- }
- },
+ return ds.groupByKey((MapFunction<Tuple2<String, Integer>, String>) value -> value._1(),
Encoders.STRING());
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a94a37cb21..577672ca8e 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -96,12 +96,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testTypedFilterPreservingSchema() {
Dataset<Long> ds = spark.range(10);
- Dataset<Long> ds2 = ds.filter(new FilterFunction<Long>() {
- @Override
- public boolean call(Long value) throws Exception {
- return value > 3;
- }
- });
+ Dataset<Long> ds2 = ds.filter((FilterFunction<Long>) value -> value > 3);
Assert.assertEquals(ds.schema(), ds2.schema());
}
@@ -111,44 +106,28 @@ public class JavaDatasetSuite implements Serializable {
Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
Assert.assertEquals("hello", ds.first());
- Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
- @Override
- public boolean call(String v) throws Exception {
- return v.startsWith("h");
- }
- });
+ Dataset<String> filtered = ds.filter((FilterFunction<String>) v -> v.startsWith("h"));
Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList());
- Dataset<Integer> mapped = ds.map(new MapFunction<String, Integer>() {
- @Override
- public Integer call(String v) throws Exception {
- return v.length();
- }
- }, Encoders.INT());
+ Dataset<Integer> mapped = ds.map((MapFunction<String, Integer>) v -> v.length(), Encoders.INT());
Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
- Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
- @Override
- public Iterator<String> call(Iterator<String> it) {
- List<String> ls = new LinkedList<>();
- while (it.hasNext()) {
- ls.add(it.next().toUpperCase(Locale.ENGLISH));
- }
- return ls.iterator();
+ Dataset<String> parMapped = ds.mapPartitions((MapPartitionsFunction<String, String>) it -> {
+ List<String> ls = new LinkedList<>();
+ while (it.hasNext()) {
+ ls.add(it.next().toUpperCase(Locale.ENGLISH));
}
+ return ls.iterator();
}, Encoders.STRING());
Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList());
- Dataset<String> flatMapped = ds.flatMap(new FlatMapFunction<String, String>() {
- @Override
- public Iterator<String> call(String s) {
- List<String> ls = new LinkedList<>();
- for (char c : s.toCharArray()) {
- ls.add(String.valueOf(c));
- }
- return ls.iterator();
+ Dataset<String> flatMapped = ds.flatMap((FlatMapFunction<String, String>) s -> {
+ List<String> ls = new LinkedList<>();
+ for (char c : s.toCharArray()) {
+ ls.add(String.valueOf(c));
}
+ return ls.iterator();
}, Encoders.STRING());
Assert.assertEquals(
Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
@@ -157,16 +136,11 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testForeach() {
- final LongAccumulator accum = jsc.sc().longAccumulator();
+ LongAccumulator accum = jsc.sc().longAccumulator();
List<String> data = Arrays.asList("a", "b", "c");
Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
- ds.foreach(new ForeachFunction<String>() {
- @Override
- public void call(String s) throws Exception {
- accum.add(1);
- }
- });
+ ds.foreach((ForeachFunction<String>) s -> accum.add(1));
Assert.assertEquals(3, accum.value().intValue());
}
@@ -175,12 +149,7 @@ public class JavaDatasetSuite implements Serializable {
List<Integer> data = Arrays.asList(1, 2, 3);
Dataset<Integer> ds = spark.createDataset(data, Encoders.INT());
- int reduced = ds.reduce(new ReduceFunction<Integer>() {
- @Override
- public Integer call(Integer v1, Integer v2) throws Exception {
- return v1 + v2;
- }
- });
+ int reduced = ds.reduce((ReduceFunction<Integer>) (v1, v2) -> v1 + v2);
Assert.assertEquals(6, reduced);
}
@@ -189,52 +158,38 @@ public class JavaDatasetSuite implements Serializable {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
KeyValueGroupedDataset<Integer, String> grouped = ds.groupByKey(
- new MapFunction<String, Integer>() {
- @Override
- public Integer call(String v) throws Exception {
- return v.length();
- }
- },
+ (MapFunction<String, Integer>) v -> v.length(),
Encoders.INT());
- Dataset<String> mapped = grouped.mapGroups(new MapGroupsFunction<Integer, String, String>() {
- @Override
- public String call(Integer key, Iterator<String> values) throws Exception {
- StringBuilder sb = new StringBuilder(key.toString());
- while (values.hasNext()) {
- sb.append(values.next());
- }
- return sb.toString();
+ Dataset<String> mapped = grouped.mapGroups((MapGroupsFunction<Integer, String, String>) (key, values) -> {
+ StringBuilder sb = new StringBuilder(key.toString());
+ while (values.hasNext()) {
+ sb.append(values.next());
}
+ return sb.toString();
}, Encoders.STRING());
Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList()));
Dataset<String> flatMapped = grouped.flatMapGroups(
- new FlatMapGroupsFunction<Integer, String, String>() {
- @Override
- public Iterator<String> call(Integer key, Iterator<String> values) {
+ (FlatMapGroupsFunction<Integer, String, String>) (key, values) -> {
StringBuilder sb = new StringBuilder(key.toString());
while (values.hasNext()) {
sb.append(values.next());
}
return Collections.singletonList(sb.toString()).iterator();
- }
- },
+ },
Encoders.STRING());
Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList()));
Dataset<String> mapped2 = grouped.mapGroupsWithState(
- new MapGroupsWithStateFunction<Integer, String, Long, String>() {
- @Override
- public String call(Integer key, Iterator<String> values, KeyedState<Long> s) {
+ (MapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
StringBuilder sb = new StringBuilder(key.toString());
while (values.hasNext()) {
sb.append(values.next());
}
return sb.toString();
- }
},
Encoders.LONG(),
Encoders.STRING());
@@ -242,27 +197,19 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped2.collectAsList()));
Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState(
- new FlatMapGroupsWithStateFunction<Integer, String, Long, String>() {
- @Override
- public Iterator<String> call(Integer key, Iterator<String> values, KeyedState<Long> s) {
+ (FlatMapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
StringBuilder sb = new StringBuilder(key.toString());
while (values.hasNext()) {
sb.append(values.next());
}
return Collections.singletonList(sb.toString()).iterator();
- }
- },
+ },
Encoders.LONG(),
Encoders.STRING());
Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList()));
- Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups(new ReduceFunction<String>() {
- @Override
- public String call(String v1, String v2) throws Exception {
- return v1 + v2;
- }
- });
+ Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups((ReduceFunction<String>) (v1, v2) -> v1 + v2);
Assert.assertEquals(
asSet(tuple2(1, "a"), tuple2(3, "foobar")),
@@ -271,29 +218,21 @@ public class JavaDatasetSuite implements Serializable {
List<Integer> data2 = Arrays.asList(2, 6, 10);
Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT());
KeyValueGroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(
- new MapFunction<Integer, Integer>() {
- @Override
- public Integer call(Integer v) throws Exception {
- return v / 2;
- }
- },
+ (MapFunction<Integer, Integer>) v -> v / 2,
Encoders.INT());
Dataset<String> cogrouped = grouped.cogroup(
grouped2,
- new CoGroupFunction<Integer, String, Integer, String>() {
- @Override
- public Iterator<String> call(Integer key, Iterator<String> left, Iterator<Integer> right) {
- StringBuilder sb = new StringBuilder(key.toString());
- while (left.hasNext()) {
- sb.append(left.next());
- }
- sb.append("#");
- while (right.hasNext()) {
- sb.append(right.next());
- }
- return Collections.singletonList(sb.toString()).iterator();
+ (CoGroupFunction<Integer, String, Integer, String>) (key, left, right) -> {
+ StringBuilder sb = new StringBuilder(key.toString());
+ while (left.hasNext()) {
+ sb.append(left.next());
+ }
+ sb.append("#");
+ while (right.hasNext()) {
+ sb.append(right.next());
}
+ return Collections.singletonList(sb.toString()).iterator();
},
Encoders.STRING());
@@ -703,11 +642,11 @@ public class JavaDatasetSuite implements Serializable {
obj1.setD(new String[]{"hello", null});
obj1.setE(Arrays.asList("a", "b"));
obj1.setF(Arrays.asList(100L, null, 200L));
- Map<Integer, String> map1 = new HashMap<Integer, String>();
+ Map<Integer, String> map1 = new HashMap<>();
map1.put(1, "a");
map1.put(2, "b");
obj1.setG(map1);
- Map<String, String> nestedMap1 = new HashMap<String, String>();
+ Map<String, String> nestedMap1 = new HashMap<>();
nestedMap1.put("x", "1");
nestedMap1.put("y", "2");
Map<List<Long>, Map<String, String>> complexMap1 = new HashMap<>();
@@ -721,11 +660,11 @@ public class JavaDatasetSuite implements Serializable {
obj2.setD(new String[]{null, "world"});
obj2.setE(Arrays.asList("x", "y"));
obj2.setF(Arrays.asList(300L, null, 400L));
- Map<Integer, String> map2 = new HashMap<Integer, String>();
+ Map<Integer, String> map2 = new HashMap<>();
map2.put(3, "c");
map2.put(4, "d");
obj2.setG(map2);
- Map<String, String> nestedMap2 = new HashMap<String, String>();
+ Map<String, String> nestedMap2 = new HashMap<>();
nestedMap2.put("q", "1");
nestedMap2.put("w", "2");
Map<List<Long>, Map<String, String>> complexMap2 = new HashMap<>();
@@ -1328,7 +1267,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void test() {
/* SPARK-15285 Large numbers of Nested JavaBeans generates more than 64KB java bytecode */
- List<NestedComplicatedJavaBean> data = new ArrayList<NestedComplicatedJavaBean>();
+ List<NestedComplicatedJavaBean> data = new ArrayList<>();
data.add(NestedComplicatedJavaBean.newBuilder().build());
NestedComplicatedJavaBean obj3 = new NestedComplicatedJavaBean();
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
index bbaac5a339..250fa674d8 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
@@ -27,7 +27,6 @@ import org.junit.Test;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.types.DataTypes;
@@ -54,16 +53,7 @@ public class JavaUDFSuite implements Serializable {
@SuppressWarnings("unchecked")
@Test
public void udf1Test() {
- // With Java 8 lambdas:
- // sqlContext.registerFunction(
- // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType);
-
- spark.udf().register("stringLengthTest", new UDF1<String, Integer>() {
- @Override
- public Integer call(String str) {
- return str.length();
- }
- }, DataTypes.IntegerType);
+ spark.udf().register("stringLengthTest", (String str) -> str.length(), DataTypes.IntegerType);
Row result = spark.sql("SELECT stringLengthTest('test')").head();
Assert.assertEquals(4, result.getInt(0));
@@ -72,18 +62,8 @@ public class JavaUDFSuite implements Serializable {
@SuppressWarnings("unchecked")
@Test
public void udf2Test() {
- // With Java 8 lambdas:
- // sqlContext.registerFunction(
- // "stringLengthTest",
- // (String str1, String str2) -> str1.length() + str2.length,
- // DataType.IntegerType);
-
- spark.udf().register("stringLengthTest", new UDF2<String, String, Integer>() {
- @Override
- public Integer call(String str1, String str2) {
- return str1.length() + str2.length();
- }
- }, DataTypes.IntegerType);
+ spark.udf().register("stringLengthTest",
+ (String str1, String str2) -> str1.length() + str2.length(), DataTypes.IntegerType);
Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
Assert.assertEquals(9, result.getInt(0));
@@ -91,8 +71,8 @@ public class JavaUDFSuite implements Serializable {
public static class StringLengthTest implements UDF2<String, String, Integer> {
@Override
- public Integer call(String str1, String str2) throws Exception {
- return new Integer(str1.length() + str2.length());
+ public Integer call(String str1, String str2) {
+ return str1.length() + str2.length();
}
}
@@ -113,12 +93,7 @@ public class JavaUDFSuite implements Serializable {
@SuppressWarnings("unchecked")
@Test
public void udf4Test() {
- spark.udf().register("inc", new UDF1<Long, Long>() {
- @Override
- public Long call(Long i) {
- return i + 1;
- }
- }, DataTypes.LongType);
+ spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType);
spark.range(10).toDF("x").createOrReplaceTempView("tmp");
// This tests when Java UDFs are required to be the semantically same (See SPARK-9435).