aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/test/java
diff options
context:
space:
mode:
author Eric Liang <ekl@databricks.com> 2016-04-05 00:30:55 -0500
committer Reynold Xin <rxin@databricks.com> 2016-04-05 00:30:55 -0500
commit 064623014e0d6dfb0376722f24e81027fde649de (patch)
tree 4ef26a921ede1724428746ae97f414e705ac9033 /sql/core/src/test/java
parent 7db56244fa3dba92246bad6694f31bbf68ea47ec (diff)
download spark-064623014e0d6dfb0376722f24e81027fde649de.tar.gz
spark-064623014e0d6dfb0376722f24e81027fde649de.tar.bz2
spark-064623014e0d6dfb0376722f24e81027fde649de.zip
[SPARK-14359] Create built-in functions for typed aggregates in Java
## What changes were proposed in this pull request?

This adds the corresponding Java static functions for built-in typed aggregates already exposed in Scala.

## How was this patch tested?

Unit tests.

rxin

Author: Eric Liang <ekl@databricks.com>

Closes #12168 from ericl/sc-2794.
Diffstat (limited to 'sql/core/src/test/java')
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java 49
1 files changed, 49 insertions, 0 deletions
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
index c4c455b6e6..c8d0eecd5c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -35,6 +35,7 @@ import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.sql.expressions.java.typed;
import org.apache.spark.sql.test.TestSQLContext;
/**
@@ -120,4 +121,52 @@ public class JavaDatasetAggregatorSuite implements Serializable {
return reduction;
}
}
+
+ @Test
+ public void testTypedAggregationAverage() {
+   // typed.avg over each value doubled; expects one (key, mean) row per group.
+   Dataset<Tuple2<String, Double>> means = generateGroupedDataset().agg(typed.avg(
+     new MapFunction<Tuple2<String, Integer>, Double>() {
+       @Override public Double call(Tuple2<String, Integer> pair) throws Exception {
+         return (double) (pair._2() * 2);
+       }
+     }));
+   Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), means.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationCount() {
+   // typed.count yields a Long column, so the expected tuples use long literals to match it.
+   Dataset<Tuple2<String, Long>> counts = generateGroupedDataset().agg(typed.count(
+     new MapFunction<Tuple2<String, Integer>, Object>() {
+       @Override public Object call(Tuple2<String, Integer> value) throws Exception {
+         return value;
+       }
+     }));
+   Assert.assertEquals(Arrays.asList(tuple2("a", 2L), tuple2("b", 1L)), counts.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationSumDouble() {
+   // typed.sum over each value widened to double; expects one (key, total) row per group.
+   Dataset<Tuple2<String, Double>> totals = generateGroupedDataset().agg(typed.sum(
+     new MapFunction<Tuple2<String, Integer>, Double>() {
+       @Override public Double call(Tuple2<String, Integer> pair) throws Exception {
+         return pair._2().doubleValue();
+       }
+     }));
+   Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), totals.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationSumLong() {
+   // typed.sumLong yields a Long column, so the expected tuples use long literals to match it.
+   Dataset<Tuple2<String, Long>> totals = generateGroupedDataset().agg(typed.sumLong(
+     new MapFunction<Tuple2<String, Integer>, Long>() {
+       @Override public Long call(Tuple2<String, Integer> value) throws Exception {
+         return (long) value._2();
+       }
+     }));
+   Assert.assertEquals(Arrays.asList(tuple2("a", 3L), tuple2("b", 3L)), totals.collectAsList());
+ }
}