author     Reynold Xin <rxin@databricks.com>    2016-04-01 22:46:56 -0700
committer  Reynold Xin <rxin@databricks.com>    2016-04-01 22:46:56 -0700
commit     f414154418c2291448954b9f0890d592b2d823ae (patch)
tree       1663d938faacb33b1607e4beb0e9ec5afdf3f493 /sql/core/src/test/java
parent     fa1af0aff7bde9bbf7bfa6a3ac74699734c2fd8a (diff)
[SPARK-14285][SQL] Implement common type-safe aggregate functions
## What changes were proposed in this pull request?

In the Dataset API, it is currently fairly difficult for users to perform simple aggregations in a type-safe way, because no aggregators have been implemented yet. This pull request adds a few common aggregate functions in the expressions.scala.typed package, and also creates the expressions.java.typed package, without an implementation for now. The Java implementation should probably come as a separate pull request; one challenge there is resolving the type difference between Scala primitive types and Java boxed types.

## How was this patch tested?

Added unit tests for them.

Author: Reynold Xin <rxin@databricks.com>

Closes #12077 from rxin/SPARK-14285.
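Concretely, the Scala side of this change is aimed at letting users write per-key aggregations without a hand-written Aggregator. A minimal sketch, assuming the aggregators land as typed.sum / typed.count / typed.avg in org.apache.spark.sql.expressions.scala.typed and that the SQL implicits (for toDS and tuple encoders) are in scope:

    import org.apache.spark.sql.expressions.scala.typed

    val ds = Seq(("a", 1), ("a", 2), ("b", 3)).toDS()

    // Type-safe aggregation: sum the second tuple element per key.
    // typed.sum takes a function from the input type to Double, so the
    // result is a Dataset[(String, Double)].
    val summed = ds.groupByKey(_._1).agg(typed.sum(_._2))
    summed.collect()  // Array(("a", 3.0), ("b", 3.0))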
Diffstat (limited to 'sql/core/src/test/java')
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                     54
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java  123
2 files changed, 123 insertions, 54 deletions
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a6c819373b..a5ab446e08 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -37,7 +37,6 @@ import org.apache.spark.SparkContext;
import org.apache.spark.api.java.function.*;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
-import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.catalyst.encoders.OuterScopes;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
@@ -385,59 +384,6 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(data, ds.collectAsList());
}
- @Test
- public void testTypedAggregation() {
- Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
- List<Tuple2<String, Integer>> data =
- Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
- Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
-
- KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey(
- new MapFunction<Tuple2<String, Integer>, String>() {
- @Override
- public String call(Tuple2<String, Integer> value) throws Exception {
- return value._1();
- }
- },
- Encoders.STRING());
-
- Dataset<Tuple2<String, Integer>> agged =
- grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
- Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
-
- Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
- new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
- .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
- Assert.assertEquals(
- Arrays.asList(
- new Tuple2<>("a", 3),
- new Tuple2<>("b", 3)),
- agged2.collectAsList());
- }
-
- static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
-
- @Override
- public Integer zero() {
- return 0;
- }
-
- @Override
- public Integer reduce(Integer l, Tuple2<String, Integer> t) {
- return l + t._2();
- }
-
- @Override
- public Integer merge(Integer b1, Integer b2) {
- return b1 + b2;
- }
-
- @Override
- public Integer finish(Integer reduction) {
- return reduction;
- }
- }
-
public static class KryoSerializable {
String value;
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
new file mode 100644
index 0000000000..c4c455b6e6
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import scala.Tuple2;
+
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.sql.test.TestSQLContext;
+
+/**
+ * Suite for testing the aggregate functionality of Datasets in Java.
+ */
+public class JavaDatasetAggregatorSuite implements Serializable {
+ private transient JavaSparkContext jsc;
+ private transient TestSQLContext context;
+
+ @Before
+ public void setUp() {
+ // Trigger static initializer of TestData
+ SparkContext sc = new SparkContext("local[*]", "testing");
+ jsc = new JavaSparkContext(sc);
+ context = new TestSQLContext(sc);
+ context.loadTestData();
+ }
+
+ @After
+ public void tearDown() {
+ context.sparkContext().stop();
+ context = null;
+ jsc = null;
+ }
+
+ private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
+ return new Tuple2<>(t1, t2);
+ }
+
+ private KeyValueGroupedDataset<String, Tuple2<String, Integer>> generateGroupedDataset() {
+ Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
+ List<Tuple2<String, Integer>> data =
+ Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
+ Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
+
+ return ds.groupByKey(
+ new MapFunction<Tuple2<String, Integer>, String>() {
+ @Override
+ public String call(Tuple2<String, Integer> value) throws Exception {
+ return value._1();
+ }
+ },
+ Encoders.STRING());
+ }
+
+ @Test
+ public void testTypedAggregationAnonClass() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+
+ Dataset<Tuple2<String, Integer>> agged =
+ grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+
+ Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
+ new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
+ .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
+ Assert.assertEquals(
+ Arrays.asList(
+ new Tuple2<>("a", 3),
+ new Tuple2<>("b", 3)),
+ agged2.collectAsList());
+ }
+
+ static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
+
+ @Override
+ public Integer zero() {
+ return 0;
+ }
+
+ @Override
+ public Integer reduce(Integer l, Tuple2<String, Integer> t) {
+ return l + t._2();
+ }
+
+ @Override
+ public Integer merge(Integer b1, Integer b2) {
+ return b1 + b2;
+ }
+
+ @Override
+ public Integer finish(Integer reduction) {
+ return reduction;
+ }
+ }
+}
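For comparison with the Java IntSumOf above, the same hand-rolled aggregator in Scala is only a few lines. A sketch, not part of this patch; it assumes implicit Int encoders are in scope so toColumn can resolve its buffer and output encoders, and reuses the grouped dataset ds from the earlier sketch:

    import org.apache.spark.sql.expressions.Aggregator

    // Mirrors the Java IntSumOf: sums the Int component of (String, Int) pairs.
    object IntSumOf extends Aggregator[(String, Int), Int, Int] {
      def zero: Int = 0                                     // identity element for the sum
      def reduce(b: Int, t: (String, Int)): Int = b + t._2  // fold one input row into the buffer
      def merge(b1: Int, b2: Int): Int = b1 + b2            // combine partial sums across partitions
      def finish(r: Int): Int = r                           // no post-processing of the final buffer
    }

    val agged = ds.groupByKey(_._1).agg(IntSumOf.toColumn)
    agged.collect()  // Array(("a", 3), ("b", 3))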