diff options
author | Eric Liang <ekl@databricks.com> | 2016-04-05 00:30:55 -0500 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-04-05 00:30:55 -0500 |
commit | 064623014e0d6dfb0376722f24e81027fde649de (patch) | |
tree | 4ef26a921ede1724428746ae97f414e705ac9033 /sql/core/src/test/java | |
parent | 7db56244fa3dba92246bad6694f31bbf68ea47ec (diff) | |
download | spark-064623014e0d6dfb0376722f24e81027fde649de.tar.gz spark-064623014e0d6dfb0376722f24e81027fde649de.tar.bz2 spark-064623014e0d6dfb0376722f24e81027fde649de.zip |
[SPARK-14359] Create built-in functions for typed aggregates in Java
## What changes were proposed in this pull request?
This adds the corresponding Java static functions for built-in typed aggregates already exposed in Scala.
## How was this patch tested?
Unit tests.
rxin
Author: Eric Liang <ekl@databricks.com>
Closes #12168 from ericl/sc-2794.
Diffstat (limited to 'sql/core/src/test/java')
-rw-r--r-- | sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java | 49 |
1 file changed, 49 insertions, 0 deletions
```diff
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
index c4c455b6e6..c8d0eecd5c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -35,6 +35,7 @@ import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.KeyValueGroupedDataset;
 import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.sql.expressions.java.typed;
 import org.apache.spark.sql.test.TestSQLContext;
 
 /**
@@ -120,4 +121,52 @@ public class JavaDatasetAggregatorSuite implements Serializable {
       return reduction;
     }
   }
+
+  @Test
+  public void testTypedAggregationAverage() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(
+      new MapFunction<Tuple2<String, Integer>, Double>() {
+        public Double call(Tuple2<String, Integer> value) throws Exception {
+          return (double)(value._2() * 2);
+        }
+      }));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
+  }
+
+  @Test
+  public void testTypedAggregationCount() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(
+      new MapFunction<Tuple2<String, Integer>, Object>() {
+        public Object call(Tuple2<String, Integer> value) throws Exception {
+          return value;
+        }
+      }));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList());
+  }
+
+  @Test
+  public void testTypedAggregationSumDouble() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(
+      new MapFunction<Tuple2<String, Integer>, Double>() {
+        public Double call(Tuple2<String, Integer> value) throws Exception {
+          return (double)value._2();
+        }
+      }));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList());
+  }
+
+  @Test
+  public void testTypedAggregationSumLong() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+    Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(
+      new MapFunction<Tuple2<String, Integer>, Long>() {
+        public Long call(Tuple2<String, Integer> value) throws Exception {
+          return (long)value._2();
+        }
+      }));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+  }
 }
```