Diffstat (limited to 'sql')
8 files changed, 108 insertions, 226 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
index 06cd9ea2d2..bf87174835 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -157,7 +157,7 @@ public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Vo
     // to the accumulator. So we can check if the row groups are filtered or not in test case.
     TaskContext taskContext = TaskContext$.MODULE$.get();
     if (taskContext != null) {
-      Option<AccumulatorV2<?, ?>> accu = (Option<AccumulatorV2<?, ?>>) taskContext.taskMetrics()
+      Option<AccumulatorV2<?, ?>> accu = taskContext.taskMetrics()
         .lookForAccumulatorByName("numRowGroups");
       if (accu.isDefined()) {
         ((LongAccumulator)accu.get()).add((long)blocks.size());
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java
index 8b8a403e2b..6ffccee52c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/Java8DatasetAggregatorSuite.java
@@ -35,27 +35,35 @@ public class Java8DatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase
   public void testTypedAggregationAverage() {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
     Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(v -> (double)(v._2() * 2)));
-    Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
+    Assert.assertEquals(
+      Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 6.0)),
+      agged.collectAsList());
   }
 
   @Test
   public void testTypedAggregationCount() {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
     Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(v -> v));
-    Assert.assertEquals(Arrays.asList(tuple2("a", 2L), tuple2("b", 1L)), agged.collectAsList());
+    Assert.assertEquals(
+      Arrays.asList(new Tuple2<>("a", 2L), new Tuple2<>("b", 1L)),
+      agged.collectAsList());
   }
 
   @Test
   public void testTypedAggregationSumDouble() {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
     Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(v -> (double)v._2()));
-    Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList());
+    Assert.assertEquals(
+      Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 3.0)),
+      agged.collectAsList());
   }
 
   @Test
   public void testTypedAggregationSumLong() {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
     Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(v -> (long)v._2()));
-    Assert.assertEquals(Arrays.asList(tuple2("a", 3L), tuple2("b", 3L)), agged.collectAsList());
+    Assert.assertEquals(
+      Arrays.asList(new Tuple2<>("a", 3L), new Tuple2<>("b", 3L)),
+      agged.collectAsList());
   }
 }
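The aggregator tests above stop using the inherited tuple2(...) helper in favor of direct new Tuple2<>(...) construction (the helper itself is deleted from JavaDatasetAggregatorSuiteBase further down). The diamond operator makes the helper redundant, since javac infers the type arguments from the target type. A minimal sketch of the pattern used in the expected lists, assuming scala.Tuple2 is on the classpath (it is for any Spark build):

    import java.util.Arrays;
    import java.util.List;
    import scala.Tuple2;

    // The diamond (<>) infers Tuple2<String, Double> from the declared element type.
    List<Tuple2<String, Double>> expected =
        Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 6.0));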
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 573d0e3594..bf8ff61eae 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -30,7 +30,6 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
@@ -95,12 +94,7 @@ public class JavaApplySchemaSuite implements Serializable {
     personList.add(person2);
 
     JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
-      new Function<Person, Row>() {
-        @Override
-        public Row call(Person person) throws Exception {
-          return RowFactory.create(person.getName(), person.getAge());
-        }
-      });
+      person -> RowFactory.create(person.getName(), person.getAge()));
 
     List<StructField> fields = new ArrayList<>(2);
     fields.add(DataTypes.createStructField("name", DataTypes.StringType, false));
@@ -131,12 +125,7 @@ public class JavaApplySchemaSuite implements Serializable {
     personList.add(person2);
 
     JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
-      new Function<Person, Row>() {
-        @Override
-        public Row call(Person person) {
-          return RowFactory.create(person.getName(), person.getAge());
-        }
-      });
+      person -> RowFactory.create(person.getName(), person.getAge()));
 
     List<StructField> fields = new ArrayList<>(2);
     fields.add(DataTypes.createStructField("", DataTypes.StringType, false));
@@ -146,12 +135,7 @@ public class JavaApplySchemaSuite implements Serializable {
     Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
     df.createOrReplaceTempView("people");
     List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD()
-      .map(new Function<Row, String>() {
-        @Override
-        public String call(Row row) {
-          return row.getString(0) + "_" + row.get(1);
-        }
-      }).collect();
+      .map(row -> row.getString(0) + "_" + row.get(1)).collect();
 
     List<String> expected = new ArrayList<>(2);
     expected.add("Michael_29");
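JavaRDD.map takes org.apache.spark.api.java.function.Function, which declares a single abstract method, so under Java 8 a lambda can stand in for the anonymous class and the explicit import disappears. A minimal sketch of the before/after shape (hypothetical Person bean, assuming a JavaSparkContext named jsc):

    // Before: anonymous inner class
    JavaRDD<Row> rows = jsc.parallelize(people).map(
      new Function<Person, Row>() {
        @Override
        public Row call(Person person) {
          return RowFactory.create(person.getName(), person.getAge());
        }
      });

    // After: a lambda; no cast is needed because JavaRDD.map has only this overload
    JavaRDD<Row> rows2 = jsc.parallelize(people).map(
      person -> RowFactory.create(person.getName(), person.getAge()));

Contrast this with the Dataset methods later in this commit, where the lambda does need an explicit functional-interface cast.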
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index c44fc3d393..c3b94a44c2 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -189,7 +189,7 @@ public class JavaDataFrameSuite {
     for (int i = 0; i < d.length(); i++) {
       Assert.assertEquals(bean.getD().get(i), d.apply(i));
     }
-    // Java.math.BigInteger is equavient to Spark Decimal(38,0)
+    // Java.math.BigInteger is equivalent to Spark Decimal(38,0)
     Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4));
   }
 
@@ -231,13 +231,10 @@ public class JavaDataFrameSuite {
     Assert.assertEquals(0, schema2.fieldIndex("id"));
   }
 
-  private static final Comparator<Row> crosstabRowComparator = new Comparator<Row>() {
-    @Override
-    public int compare(Row row1, Row row2) {
-      String item1 = row1.getString(0);
-      String item2 = row2.getString(0);
-      return item1.compareTo(item2);
-    }
+  private static final Comparator<Row> crosstabRowComparator = (row1, row2) -> {
+    String item1 = row1.getString(0);
+    String item2 = row2.getString(0);
+    return item1.compareTo(item2);
   };
 
   @Test
@@ -249,7 +246,7 @@ public class JavaDataFrameSuite {
     Assert.assertEquals("1", columnNames[1]);
     Assert.assertEquals("2", columnNames[2]);
     List<Row> rows = crosstab.collectAsList();
-    Collections.sort(rows, crosstabRowComparator);
+    rows.sort(crosstabRowComparator);
     Integer count = 1;
     for (Row row : rows) {
       Assert.assertEquals(row.get(0).toString(), count.toString());
@@ -284,7 +281,7 @@ public class JavaDataFrameSuite {
   @Test
   public void testSampleBy() {
     Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
-    Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
+    Dataset<Row> sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
     List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
     Assert.assertEquals(0, actual.get(0).getLong(0));
     Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
@@ -296,7 +293,7 @@ public class JavaDataFrameSuite {
   public void pivot() {
     Dataset<Row> df = spark.table("courseSales");
     List<Row> actual = df.groupBy("year")
-      .pivot("course", Arrays.<Object>asList("dotNET", "Java"))
+      .pivot("course", Arrays.asList("dotNET", "Java"))
       .agg(sum("earnings")).orderBy("year").collectAsList();
 
     Assert.assertEquals(2012, actual.get(0).getInt(0));
@@ -352,24 +349,24 @@ public class JavaDataFrameSuite {
     Dataset<Long> df = spark.range(1000);
 
     CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
-    Assert.assertEquals(sketch1.totalCount(), 1000);
-    Assert.assertEquals(sketch1.depth(), 10);
-    Assert.assertEquals(sketch1.width(), 20);
+    Assert.assertEquals(1000, sketch1.totalCount());
+    Assert.assertEquals(10, sketch1.depth());
+    Assert.assertEquals(20, sketch1.width());
 
     CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42);
-    Assert.assertEquals(sketch2.totalCount(), 1000);
-    Assert.assertEquals(sketch2.depth(), 10);
-    Assert.assertEquals(sketch2.width(), 20);
+    Assert.assertEquals(1000, sketch2.totalCount());
+    Assert.assertEquals(10, sketch2.depth());
+    Assert.assertEquals(20, sketch2.width());
 
     CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42);
-    Assert.assertEquals(sketch3.totalCount(), 1000);
-    Assert.assertEquals(sketch3.relativeError(), 0.001, 1e-4);
-    Assert.assertEquals(sketch3.confidence(), 0.99, 5e-3);
+    Assert.assertEquals(1000, sketch3.totalCount());
+    Assert.assertEquals(0.001, sketch3.relativeError(), 1.0e-4);
+    Assert.assertEquals(0.99, sketch3.confidence(), 5.0e-3);
 
     CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42);
-    Assert.assertEquals(sketch4.totalCount(), 1000);
-    Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4);
-    Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3);
+    Assert.assertEquals(1000, sketch4.totalCount());
+    Assert.assertEquals(0.001, sketch4.relativeError(), 1.0e-4);
+    Assert.assertEquals(0.99, sketch4.confidence(), 5.0e-3);
   }
 
   @Test
@@ -389,13 +386,13 @@ public class JavaDataFrameSuite {
     }
 
     BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5);
-    Assert.assertTrue(filter3.bitSize() == 64 * 5);
+    Assert.assertEquals(64 * 5, filter3.bitSize());
     for (int i = 0; i < 1000; i++) {
       Assert.assertTrue(filter3.mightContain(i));
     }
 
     BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5);
-    Assert.assertTrue(filter4.bitSize() == 64 * 5);
+    Assert.assertEquals(64 * 5, filter4.bitSize());
     for (int i = 0; i < 1000; i++) {
       Assert.assertTrue(filter4.mightContain(i * 3));
     }
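Several assertions above are flipped so that the expected value comes first. JUnit's Assert.assertEquals(expected, actual) reports failures as "expected:<...> but was:<...>", so reversed arguments produce a misleading message that presents the live value as the expectation. A small sketch of the two fixes made in this file:

    // Reversed: on failure this would report the live count as "expected".
    Assert.assertEquals(sketch1.totalCount(), 1000);

    // Correct: expected value first, actual second.
    Assert.assertEquals(1000, sketch1.totalCount());

    // assertEquals also beats assertTrue(a == b): on failure it prints both
    // values instead of a bare AssertionError with no context.
    Assert.assertEquals(64 * 5, filter3.bitSize());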
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java
index fe86371516..d3769a74b9 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java
@@ -24,7 +24,6 @@ import scala.Tuple2;
 import org.junit.Assert;
 import org.junit.Test;
 
-import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
@@ -41,7 +40,9 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
 
     Dataset<Tuple2<String, Integer>> agged = grouped.agg(new IntSumOf().toColumn());
-    Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+    Assert.assertEquals(
+      Arrays.asList(new Tuple2<>("a", 3), new Tuple2<>("b", 3)),
+      agged.collectAsList());
 
     Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(new IntSumOf().toColumn())
       .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
@@ -87,48 +88,36 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
   @Test
   public void testTypedAggregationAverage() {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
-    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(
-      new MapFunction<Tuple2<String, Integer>, Double>() {
-        public Double call(Tuple2<String, Integer> value) throws Exception {
-          return (double)(value._2() * 2);
-        }
-      }));
-    Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
+    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(value -> (double)(value._2() * 2)));
+    Assert.assertEquals(
+      Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 6.0)),
+      agged.collectAsList());
   }
 
   @Test
   public void testTypedAggregationCount() {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
-    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(
-      new MapFunction<Tuple2<String, Integer>, Object>() {
-        public Object call(Tuple2<String, Integer> value) throws Exception {
-          return value;
-        }
-      }));
-    Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList());
+    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(value -> value));
+    Assert.assertEquals(
+      Arrays.asList(new Tuple2<>("a", 2L), new Tuple2<>("b", 1L)),
+      agged.collectAsList());
   }
 
   @Test
   public void testTypedAggregationSumDouble() {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
-    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(
-      new MapFunction<Tuple2<String, Integer>, Double>() {
-        public Double call(Tuple2<String, Integer> value) throws Exception {
-          return (double)value._2();
-        }
-      }));
-    Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList());
+    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(value -> (double) value._2()));
+    Assert.assertEquals(
+      Arrays.asList(new Tuple2<>("a", 3.0), new Tuple2<>("b", 3.0)),
+      agged.collectAsList());
   }
 
   @Test
   public void testTypedAggregationSumLong() {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
-    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(
-      new MapFunction<Tuple2<String, Integer>, Long>() {
-        public Long call(Tuple2<String, Integer> value) throws Exception {
-          return (long)value._2();
-        }
-      }));
-    Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(value -> (long) value._2()));
+    Assert.assertEquals(
+      Arrays.asList(new Tuple2<>("a", 3L), new Tuple2<>("b", 3L)),
+      agged.collectAsList());
   }
 }
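The typed aggregator helpers used above (typed.avg, typed.count, typed.sum, typed.sumLong in org.apache.spark.sql.expressions.javalang.typed) each take a single MapFunction parameter, so bare lambdas work without the cast that the overloaded Dataset methods need. Note also that the expected values for typed.count become longs, matching what the aggregator actually returns. A minimal sketch, assuming grouped is the KeyValueGroupedDataset these tests build:

    // typed.count takes one MapFunction<T, Object>; the identity lambda counts records per key.
    Dataset<Tuple2<String, Long>> counts = grouped.agg(typed.count(value -> value));
    // The counts are Longs, so the expected tuples use 2L and 1L rather than 2 and 1.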
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java
index 8fc4eff55d..e62db7d2cf 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java
@@ -52,23 +52,13 @@ public class JavaDatasetAggregatorSuiteBase implements Serializable {
     spark = null;
   }
 
-  protected <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
-    return new Tuple2<>(t1, t2);
-  }
-
   protected KeyValueGroupedDataset<String, Tuple2<String, Integer>> generateGroupedDataset() {
     Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
     List<Tuple2<String, Integer>> data =
-      Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
+      Arrays.asList(new Tuple2<>("a", 1), new Tuple2<>("a", 2), new Tuple2<>("b", 3));
     Dataset<Tuple2<String, Integer>> ds = spark.createDataset(data, encoder);
 
-    return ds.groupByKey(
-      new MapFunction<Tuple2<String, Integer>, String>() {
-        @Override
-        public String call(Tuple2<String, Integer> value) throws Exception {
-          return value._1();
-        }
-      },
+    return ds.groupByKey((MapFunction<Tuple2<String, Integer>, String>) value -> value._1(),
       Encoders.STRING());
   }
 }
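The cast in front of the lambda in groupByKey above is not decorative. Dataset and KeyValueGroupedDataset methods such as map, mapPartitions, flatMap, filter, and groupByKey are overloaded to accept both a Scala function and the corresponding Java functional interface; a bare lambda matches more than one overload, so the compiler rejects the call as ambiguous, and the cast picks the Java overload. A minimal sketch of the pattern (assuming a SparkSession named spark):

    import java.util.Arrays;
    import org.apache.spark.api.java.function.MapFunction;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;

    Dataset<String> ds = spark.createDataset(Arrays.asList("ab", "c"), Encoders.STRING());

    // Ambiguous, does not compile: map(...) also has a scala.Function1 overload.
    // Dataset<Integer> lengths = ds.map(v -> v.length(), Encoders.INT());

    // The cast selects the Java MapFunction overload explicitly.
    Dataset<Integer> lengths =
        ds.map((MapFunction<String, Integer>) v -> v.length(), Encoders.INT());

Where a method has only the Java functional-interface overload (JavaRDD.map, or the typed.* aggregators above), the lambda needs no cast.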
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a94a37cb21..577672ca8e 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -96,12 +96,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testTypedFilterPreservingSchema() {
     Dataset<Long> ds = spark.range(10);
-    Dataset<Long> ds2 = ds.filter(new FilterFunction<Long>() {
-      @Override
-      public boolean call(Long value) throws Exception {
-        return value > 3;
-      }
-    });
+    Dataset<Long> ds2 = ds.filter((FilterFunction<Long>) value -> value > 3);
     Assert.assertEquals(ds.schema(), ds2.schema());
   }
 
@@ -111,44 +106,28 @@
     Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
     Assert.assertEquals("hello", ds.first());
 
-    Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
-      @Override
-      public boolean call(String v) throws Exception {
-        return v.startsWith("h");
-      }
-    });
+    Dataset<String> filtered = ds.filter((FilterFunction<String>) v -> v.startsWith("h"));
     Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList());
 
-    Dataset<Integer> mapped = ds.map(new MapFunction<String, Integer>() {
-      @Override
-      public Integer call(String v) throws Exception {
-        return v.length();
-      }
-    }, Encoders.INT());
+    Dataset<Integer> mapped = ds.map((MapFunction<String, Integer>) v -> v.length(), Encoders.INT());
     Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
 
-    Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
-      @Override
-      public Iterator<String> call(Iterator<String> it) {
-        List<String> ls = new LinkedList<>();
-        while (it.hasNext()) {
-          ls.add(it.next().toUpperCase(Locale.ENGLISH));
-        }
-        return ls.iterator();
+    Dataset<String> parMapped = ds.mapPartitions((MapPartitionsFunction<String, String>) it -> {
+      List<String> ls = new LinkedList<>();
+      while (it.hasNext()) {
+        ls.add(it.next().toUpperCase(Locale.ENGLISH));
       }
+      return ls.iterator();
     }, Encoders.STRING());
     Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList());
 
-    Dataset<String> flatMapped = ds.flatMap(new FlatMapFunction<String, String>() {
-      @Override
-      public Iterator<String> call(String s) {
-        List<String> ls = new LinkedList<>();
-        for (char c : s.toCharArray()) {
-          ls.add(String.valueOf(c));
-        }
-        return ls.iterator();
+    Dataset<String> flatMapped = ds.flatMap((FlatMapFunction<String, String>) s -> {
+      List<String> ls = new LinkedList<>();
+      for (char c : s.toCharArray()) {
+        ls.add(String.valueOf(c));
       }
+      return ls.iterator();
     }, Encoders.STRING());
     Assert.assertEquals(
       Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
@@ -157,16 +136,11 @@
 
   @Test
   public void testForeach() {
-    final LongAccumulator accum = jsc.sc().longAccumulator();
+    LongAccumulator accum = jsc.sc().longAccumulator();
     List<String> data = Arrays.asList("a", "b", "c");
     Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
 
-    ds.foreach(new ForeachFunction<String>() {
-      @Override
-      public void call(String s) throws Exception {
-        accum.add(1);
-      }
-    });
+    ds.foreach((ForeachFunction<String>) s -> accum.add(1));
     Assert.assertEquals(3, accum.value().intValue());
   }
 
@@ -175,12 +149,7 @@
     List<Integer> data = Arrays.asList(1, 2, 3);
     Dataset<Integer> ds = spark.createDataset(data, Encoders.INT());
 
-    int reduced = ds.reduce(new ReduceFunction<Integer>() {
-      @Override
-      public Integer call(Integer v1, Integer v2) throws Exception {
-        return v1 + v2;
-      }
-    });
+    int reduced = ds.reduce((ReduceFunction<Integer>) (v1, v2) -> v1 + v2);
     Assert.assertEquals(6, reduced);
   }
 
@@ -189,52 +158,38 @@
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
     KeyValueGroupedDataset<Integer, String> grouped = ds.groupByKey(
-      new MapFunction<String, Integer>() {
-        @Override
-        public Integer call(String v) throws Exception {
-          return v.length();
-        }
-      },
+      (MapFunction<String, Integer>) v -> v.length(),
       Encoders.INT());
 
-    Dataset<String> mapped = grouped.mapGroups(new MapGroupsFunction<Integer, String, String>() {
-      @Override
-      public String call(Integer key, Iterator<String> values) throws Exception {
-        StringBuilder sb = new StringBuilder(key.toString());
-        while (values.hasNext()) {
-          sb.append(values.next());
-        }
-        return sb.toString();
+    Dataset<String> mapped = grouped.mapGroups((MapGroupsFunction<Integer, String, String>) (key, values) -> {
+      StringBuilder sb = new StringBuilder(key.toString());
+      while (values.hasNext()) {
+        sb.append(values.next());
       }
+      return sb.toString();
     }, Encoders.STRING());
     Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList()));
 
     Dataset<String> flatMapped = grouped.flatMapGroups(
-      new FlatMapGroupsFunction<Integer, String, String>() {
-        @Override
-        public Iterator<String> call(Integer key, Iterator<String> values) {
+      (FlatMapGroupsFunction<Integer, String, String>) (key, values) -> {
         StringBuilder sb = new StringBuilder(key.toString());
         while (values.hasNext()) {
          sb.append(values.next());
        }
        return Collections.singletonList(sb.toString()).iterator();
-        }
-      },
+      },
       Encoders.STRING());
     Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList()));
 
     Dataset<String> mapped2 = grouped.mapGroupsWithState(
-      new MapGroupsWithStateFunction<Integer, String, Long, String>() {
-        @Override
-        public String call(Integer key, Iterator<String> values, KeyedState<Long> s) {
+      (MapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
        StringBuilder sb = new StringBuilder(key.toString());
        while (values.hasNext()) {
          sb.append(values.next());
        }
        return sb.toString();
-        }
       },
       Encoders.LONG(),
       Encoders.STRING());
@@ -242,27 +197,19 @@
     Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped2.collectAsList()));
 
     Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState(
-      new FlatMapGroupsWithStateFunction<Integer, String, Long, String>() {
-        @Override
-        public Iterator<String> call(Integer key, Iterator<String> values, KeyedState<Long> s) {
+      (FlatMapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
        StringBuilder sb = new StringBuilder(key.toString());
        while (values.hasNext()) {
          sb.append(values.next());
        }
        return Collections.singletonList(sb.toString()).iterator();
-        }
-      },
+      },
       Encoders.LONG(),
       Encoders.STRING());
     Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList()));
 
-    Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups(new ReduceFunction<String>() {
-      @Override
-      public String call(String v1, String v2) throws Exception {
-        return v1 + v2;
-      }
-    });
+    Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups((ReduceFunction<String>) (v1, v2) -> v1 + v2);
 
     Assert.assertEquals(
       asSet(tuple2(1, "a"), tuple2(3, "foobar")),
@@ -271,29 +218,21 @@
     List<Integer> data2 = Arrays.asList(2, 6, 10);
     Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT());
     KeyValueGroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(
-      new MapFunction<Integer, Integer>() {
-        @Override
-        public Integer call(Integer v) throws Exception {
-          return v / 2;
-        }
-      },
+      (MapFunction<Integer, Integer>) v -> v / 2,
       Encoders.INT());
 
     Dataset<String> cogrouped = grouped.cogroup(
       grouped2,
-      new CoGroupFunction<Integer, String, Integer, String>() {
-        @Override
-        public Iterator<String> call(Integer key, Iterator<String> left, Iterator<Integer> right) {
-          StringBuilder sb = new StringBuilder(key.toString());
-          while (left.hasNext()) {
-            sb.append(left.next());
-          }
-          sb.append("#");
-          while (right.hasNext()) {
-            sb.append(right.next());
-          }
-          return Collections.singletonList(sb.toString()).iterator();
+      (CoGroupFunction<Integer, String, Integer, String>) (key, left, right) -> {
+        StringBuilder sb = new StringBuilder(key.toString());
+        while (left.hasNext()) {
+          sb.append(left.next());
+        }
+        sb.append("#");
+        while (right.hasNext()) {
+          sb.append(right.next());
         }
+        return Collections.singletonList(sb.toString()).iterator();
       },
       Encoders.STRING());
@@ -703,11 +642,11 @@ public class JavaDatasetSuite implements Serializable {
     obj1.setD(new String[]{"hello", null});
     obj1.setE(Arrays.asList("a", "b"));
     obj1.setF(Arrays.asList(100L, null, 200L));
-    Map<Integer, String> map1 = new HashMap<Integer, String>();
+    Map<Integer, String> map1 = new HashMap<>();
     map1.put(1, "a");
     map1.put(2, "b");
     obj1.setG(map1);
-    Map<String, String> nestedMap1 = new HashMap<String, String>();
+    Map<String, String> nestedMap1 = new HashMap<>();
     nestedMap1.put("x", "1");
     nestedMap1.put("y", "2");
     Map<List<Long>, Map<String, String>> complexMap1 = new HashMap<>();
@@ -721,11 +660,11 @@
     obj2.setD(new String[]{null, "world"});
     obj2.setE(Arrays.asList("x", "y"));
     obj2.setF(Arrays.asList(300L, null, 400L));
-    Map<Integer, String> map2 = new HashMap<Integer, String>();
+    Map<Integer, String> map2 = new HashMap<>();
     map2.put(3, "c");
     map2.put(4, "d");
     obj2.setG(map2);
-    Map<String, String> nestedMap2 = new HashMap<String, String>();
+    Map<String, String> nestedMap2 = new HashMap<>();
     nestedMap2.put("q", "1");
     nestedMap2.put("w", "2");
     Map<List<Long>, Map<String, String>> complexMap2 = new HashMap<>();
@@ -1328,7 +1267,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void test() {
     /* SPARK-15285 Large numbers of Nested JavaBeans generates more than 64KB java bytecode */
-    List<NestedComplicatedJavaBean> data = new ArrayList<NestedComplicatedJavaBean>();
+    List<NestedComplicatedJavaBean> data = new ArrayList<>();
     data.add(NestedComplicatedJavaBean.newBuilder().build());
 
     NestedComplicatedJavaBean obj3 = new NestedComplicatedJavaBean();
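Anonymous classes whose call bodies span several statements become lambdas with brace blocks: the statements and the explicit return survive unchanged, while the @Override and method-signature boilerplate disappears. A minimal sketch of the shape, assuming a Dataset<String> named ds and the usual java.util imports:

    Dataset<String> upper = ds.mapPartitions((MapPartitionsFunction<String, String>) it -> {
      List<String> out = new ArrayList<>();   // local state, exactly as in the old call body
      while (it.hasNext()) {
        out.add(it.next().toUpperCase(Locale.ROOT));
      }
      return out.iterator();                  // explicit return, as in the old method body
    }, Encoders.STRING());

One related detail the testForeach change illustrates: a lambda, like an anonymous class under Java 8, may capture any effectively final local, so the explicit final modifier on accum is no longer required.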
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
index bbaac5a339..250fa674d8 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
@@ -27,7 +27,6 @@ import org.junit.Test;
 
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.api.java.UDF1;
 import org.apache.spark.sql.api.java.UDF2;
 import org.apache.spark.sql.types.DataTypes;
 
@@ -54,16 +53,7 @@
   @SuppressWarnings("unchecked")
   @Test
   public void udf1Test() {
-    // With Java 8 lambdas:
-    // sqlContext.registerFunction(
-    //   "stringLengthTest", (String str) -> str.length(), DataType.IntegerType);
-
-    spark.udf().register("stringLengthTest", new UDF1<String, Integer>() {
-      @Override
-      public Integer call(String str) {
-        return str.length();
-      }
-    }, DataTypes.IntegerType);
+    spark.udf().register("stringLengthTest", (String str) -> str.length(), DataTypes.IntegerType);
 
     Row result = spark.sql("SELECT stringLengthTest('test')").head();
     Assert.assertEquals(4, result.getInt(0));
@@ -72,18 +62,8 @@
   @SuppressWarnings("unchecked")
   @Test
   public void udf2Test() {
-    // With Java 8 lambdas:
-    // sqlContext.registerFunction(
-    //   "stringLengthTest",
-    //   (String str1, String str2) -> str1.length() + str2.length,
-    //   DataType.IntegerType);
-
-    spark.udf().register("stringLengthTest", new UDF2<String, String, Integer>() {
-      @Override
-      public Integer call(String str1, String str2) {
-        return str1.length() + str2.length();
-      }
-    }, DataTypes.IntegerType);
+    spark.udf().register("stringLengthTest",
+      (String str1, String str2) -> str1.length() + str2.length(), DataTypes.IntegerType);
 
     Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
     Assert.assertEquals(9, result.getInt(0));
@@ -91,8 +71,8 @@
 
   public static class StringLengthTest implements UDF2<String, String, Integer> {
     @Override
-    public Integer call(String str1, String str2) throws Exception {
-      return new Integer(str1.length() + str2.length());
+    public Integer call(String str1, String str2) {
+      return str1.length() + str2.length();
     }
   }
@@ -113,12 +93,7 @@ public class JavaUDFSuite implements Serializable {
   @SuppressWarnings("unchecked")
   @Test
   public void udf4Test() {
-    spark.udf().register("inc", new UDF1<Long, Long>() {
-      @Override
-      public Long call(Long i) {
-        return i + 1;
-      }
-    }, DataTypes.LongType);
+    spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType);
 
     spark.range(10).toDF("x").createOrReplaceTempView("tmp");
     // This tests when Java UDFs are required to be the semantically same (See SPARK-9435).
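The StringLengthTest change above also drops the boxed-constructor call: new Integer(...) allocates a fresh object on every invocation (and was later deprecated in Java 9), whereas returning the int and letting autoboxing call Integer.valueOf reuses cached instances for small values. A minimal sketch of the difference:

    Integer a = new Integer(42);     // always a new allocation
    Integer b = Integer.valueOf(42); // cached for values in [-128, 127]
    Integer c = 42;                  // autoboxing compiles to Integer.valueOf(42)

One last observation on the UDF lambdas: the parameter types are written out explicitly ((String str) -> ..., (Long i) -> ...). This appears to be necessary because the register(...) overloads declare their UDF parameters with wildcard types such as UDF1<?, ?>, so an implicitly typed lambda has no concrete parameter type to infer; spelling out the type lets the body compile against it.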