field | value | date
---|---|---
author | Reynold Xin <rxin@databricks.com> | 2016-04-01 22:46:56 -0700
committer | Reynold Xin <rxin@databricks.com> | 2016-04-01 22:46:56 -0700
commit | f414154418c2291448954b9f0890d592b2d823ae (patch) |
tree | 1663d938faacb33b1607e4beb0e9ec5afdf3f493 |
parent | fa1af0aff7bde9bbf7bfa6a3ac74699734c2fd8a (diff) |
download | spark-f414154418c2291448954b9f0890d592b2d823ae.tar.gz, spark-f414154418c2291448954b9f0890d592b2d823ae.tar.bz2, spark-f414154418c2291448954b9f0890d592b2d823ae.zip |
# [SPARK-14285][SQL] Implement common type-safe aggregate functions
## What changes were proposed in this pull request?
In the Dataset API it is currently difficult for users to perform even simple aggregations in a type-safe way, because no aggregators have been implemented. This pull request adds a few common aggregate functions to the expressions.scala.typed package, and creates the expressions.java.typed package as a placeholder with no implementation yet. The Java implementation should come as a separate pull request; one challenge there is resolving the type difference between Scala primitive types and Java boxed types.
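To make the intent concrete, here is a minimal sketch of the new API in use, mirroring the Scala tests added in this patch (a `SQLContext` named `sqlContext` and its implicits are assumed to be in scope, as they are in the test suites):

```scala
import org.apache.spark.sql.expressions.scala.typed

// The implicits provide the encoders needed by toDS(); in the tests this is
// `import testImplicits._` from SharedSQLContext.
import sqlContext.implicits._

val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()

// Each aggregator takes a typed lambda rather than a column name string, so
// both the grouping key and the aggregated value are checked at compile time.
val agged = ds.groupByKey(_._1).agg(
  typed.avg(_._2),      // TypedColumn[(String, Int), Double]
  typed.count(_._2),    // TypedColumn[(String, Int), Long], skips nulls
  typed.sum(_._2),      // sums as Double
  typed.sumLong(_._2))  // sums as Long

// Expected rows, per the new test: ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L)
agged.show()
```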
## How was this patch tested?
Added unit tests for them.
Author: Reynold Xin <rxin@databricks.com>
Closes #12077 from rxin/SPARK-14285.
9 files changed, 342 insertions, 111 deletions
```diff
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
new file mode 100644
index 0000000000..9afc29038b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.expressions.Aggregator
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// This file defines internal implementations for aggregators.
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] {
+  val numeric = implicitly[Numeric[OUT]]
+  override def zero: OUT = numeric.zero
+  override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
+  override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
+  override def finish(reduction: OUT): OUT = reduction
+}
+
+
+class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
+  override def zero: Double = 0.0
+  override def reduce(b: Double, a: IN): Double = b + f(a)
+  override def merge(b1: Double, b2: Double): Double = b1 + b2
+  override def finish(reduction: Double): Double = reduction
+}
+
+
+class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
+  override def zero: Long = 0L
+  override def reduce(b: Long, a: IN): Long = b + f(a)
+  override def merge(b1: Long, b2: Long): Long = b1 + b2
+  override def finish(reduction: Long): Long = reduction
+}
+
+
+class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
+  override def zero: Long = 0
+  override def reduce(b: Long, a: IN): Long = {
+    if (f(a) == null) b else b + 1
+  }
+  override def merge(b1: Long, b2: Long): Long = b1 + b2
+  override def finish(reduction: Long): Long = reduction
+}
+
+
+class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
+  override def zero: (Double, Long) = (0.0, 0L)
+  override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2)
+  override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
+  override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = {
+    (b1._1 + b2._1, b1._2 + b2._2)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
index e9b60841fc..350c283646 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
@@ -42,7 +42,7 @@ object Window {
    * Creates a [[WindowSpec]] with the partitioning defined.
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def partitionBy(colName: String, colNames: String*): WindowSpec = {
     spec.partitionBy(colName, colNames : _*)
   }
@@ -51,7 +51,7 @@ object Window {
    * Creates a [[WindowSpec]] with the partitioning defined.
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def partitionBy(cols: Column*): WindowSpec = {
     spec.partitionBy(cols : _*)
   }
@@ -60,7 +60,7 @@ object Window {
    * Creates a [[WindowSpec]] with the ordering defined.
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def orderBy(colName: String, colNames: String*): WindowSpec = {
     spec.orderBy(colName, colNames : _*)
   }
@@ -69,7 +69,7 @@ object Window {
    * Creates a [[WindowSpec]] with the ordering defined.
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def orderBy(cols: Column*): WindowSpec = {
     spec.orderBy(cols : _*)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index 9e9c58cb66..d716da2668 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -39,7 +39,7 @@ class WindowSpec private[sql](
    * Defines the partitioning columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def partitionBy(colName: String, colNames: String*): WindowSpec = {
     partitionBy((colName +: colNames).map(Column(_)): _*)
   }
@@ -48,7 +48,7 @@ class WindowSpec private[sql](
    * Defines the partitioning columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def partitionBy(cols: Column*): WindowSpec = {
     new WindowSpec(cols.map(_.expr), orderSpec, frame)
   }
@@ -57,7 +57,7 @@ class WindowSpec private[sql](
    * Defines the ordering columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def orderBy(colName: String, colNames: String*): WindowSpec = {
     orderBy((colName +: colNames).map(Column(_)): _*)
   }
@@ -66,7 +66,7 @@ class WindowSpec private[sql](
    * Defines the ordering columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def orderBy(cols: Column*): WindowSpec = {
     val sortOrder: Seq[SortOrder] = cols.map { col =>
       col.expr match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
new file mode 100644
index 0000000000..cdba970d8f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.java;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.Dataset;
+
+/**
+ * :: Experimental ::
+ * Type-safe functions available for {@link Dataset} operations in Java.
+ *
+ * Scala users should use {@link org.apache.spark.sql.expressions.scala.typed}.
+ *
+ * @since 2.0.0
+ */
+@Experimental
+public class typed {
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala
new file mode 100644
index 0000000000..d0eb190afd
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.scala
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.aggregate._
+
+/**
+ * :: Experimental ::
+ * Type-safe functions available for [[Dataset]] operations in Scala.
+ *
+ * Java users should use [[org.apache.spark.sql.expressions.java.typed]].
+ *
+ * @since 2.0.0
+ */
+@Experimental
+// scalastyle:off
+object typed {
+  // scalastyle:on
+
+  // Note: whenever we update this file, we should update the corresponding Java version too.
+  // The reason we have separate files for Java and Scala is because in the Scala version, we can
+  // use tighter types (primitive types) for return types, whereas in the Java version we can only
+  // use boxed primitive types.
+  // For example, avg in the Scala version returns Scala primitive Double, whose bytecode
+  // signature is just a java.lang.Object; avg in the Java version returns java.lang.Double.
+
+  // TODO: This is pretty hacky. Maybe we should have an object for implicit encoders.
+  private val implicits = new SQLImplicits {
+    override protected def _sqlContext: SQLContext = null
+  }
+
+  import implicits._
+
+  /**
+   * Average aggregate function.
+   *
+   * @since 2.0.0
+   */
+  def avg[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedAverage(f).toColumn
+
+  /**
+   * Count aggregate function.
+   *
+   * @since 2.0.0
+   */
+  def count[IN](f: IN => Any): TypedColumn[IN, Long] = new TypedCount(f).toColumn
+
+  /**
+   * Sum aggregate function for floating point (double) type.
+   *
+   * @since 2.0.0
+   */
+  def sum[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedSumDouble[IN](f).toColumn
+
+  /**
+   * Sum aggregate function for integral (long, i.e. 64 bit integer) type.
+   *
+   * @since 2.0.0
+   */
+  def sumLong[IN](f: IN => Long): TypedColumn[IN, Long] = new TypedSumLong[IN](f).toColumn
+
+  // TODO:
+  // stddevOf: Double
+  // varianceOf: Double
+  // approxCountDistinct: Long
+
+  // minOf: T
+  // maxOf: T
+
+  // firstOf: T
+  // lastOf: T
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index 8b355befc3..48925910ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -106,7 +106,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
   /**
    * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def apply(exprs: Column*): Column = {
     val aggregateExpression =
       AggregateExpression(
@@ -120,7 +120,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
    * Creates a [[Column]] for this UDAF using the distinct values of the given
    * [[Column]]s as input arguments.
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def distinct(exprs: Column*): Column = {
     val aggregateExpression =
       AggregateExpression(
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a6c819373b..a5ab446e08 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -37,7 +37,6 @@ import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.*;
-import org.apache.spark.sql.expressions.Aggregator;
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.catalyst.encoders.OuterScopes;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
@@ -385,59 +384,6 @@ public class JavaDatasetSuite implements Serializable {
     Assert.assertEquals(data, ds.collectAsList());
   }
 
-  @Test
-  public void testTypedAggregation() {
-    Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
-    List<Tuple2<String, Integer>> data =
-      Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
-    Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
-
-    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey(
-      new MapFunction<Tuple2<String, Integer>, String>() {
-        @Override
-        public String call(Tuple2<String, Integer> value) throws Exception {
-          return value._1();
-        }
-      },
-      Encoders.STRING());
-
-    Dataset<Tuple2<String, Integer>> agged =
-      grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
-    Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
-
-    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
-      new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
-      .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
-    Assert.assertEquals(
-      Arrays.asList(
-        new Tuple2<>("a", 3),
-        new Tuple2<>("b", 3)),
-      agged2.collectAsList());
-  }
-
-  static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
-
-    @Override
-    public Integer zero() {
-      return 0;
-    }
-
-    @Override
-    public Integer reduce(Integer l, Tuple2<String, Integer> t) {
-      return l + t._2();
-    }
-
-    @Override
-    public Integer merge(Integer b1, Integer b2) {
-      return b1 + b2;
-    }
-
-    @Override
-    public Integer finish(Integer reduction) {
-      return reduction;
-    }
-  }
-
   public static class KryoSerializable {
     String value;
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
new file mode 100644
index 0000000000..c4c455b6e6
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import scala.Tuple2;
+
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.sql.test.TestSQLContext;
+
+/**
+ * Suite for testing the aggregate functionality of Datasets in Java.
+ */
+public class JavaDatasetAggregatorSuite implements Serializable {
+  private transient JavaSparkContext jsc;
+  private transient TestSQLContext context;
+
+  @Before
+  public void setUp() {
+    // Trigger static initializer of TestData
+    SparkContext sc = new SparkContext("local[*]", "testing");
+    jsc = new JavaSparkContext(sc);
+    context = new TestSQLContext(sc);
+    context.loadTestData();
+  }
+
+  @After
+  public void tearDown() {
+    context.sparkContext().stop();
+    context = null;
+    jsc = null;
+  }
+
+  private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
+    return new Tuple2<>(t1, t2);
+  }
+
+  private KeyValueGroupedDataset<String, Tuple2<String, Integer>> generateGroupedDataset() {
+    Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
+    List<Tuple2<String, Integer>> data =
+      Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
+    Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
+
+    return ds.groupByKey(
+      new MapFunction<Tuple2<String, Integer>, String>() {
+        @Override
+        public String call(Tuple2<String, Integer> value) throws Exception {
+          return value._1();
+        }
+      },
+      Encoders.STRING());
+  }
+
+  @Test
+  public void testTypedAggregationAnonClass() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+
+    Dataset<Tuple2<String, Integer>> agged =
+      grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+
+    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
+      new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
+      .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
+    Assert.assertEquals(
+      Arrays.asList(
+        new Tuple2<>("a", 3),
+        new Tuple2<>("b", 3)),
+      agged2.collectAsList());
+  }
+
+  static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
+
+    @Override
+    public Integer zero() {
+      return 0;
+    }
+
+    @Override
+    public Integer reduce(Integer l, Tuple2<String, Integer> t) {
+      return l + t._2();
+    }
+
+    @Override
+    public Integer merge(Integer b1, Integer b2) {
+      return b1 + b2;
+    }
+
+    @Override
+    public Integer finish(Integer reduction) {
+      return reduction;
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 84770169f0..5430aff6ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -20,35 +20,10 @@ package org.apache.spark.sql
 import scala.language.postfixOps
 
 import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.expressions.scala.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 
-/** An `Aggregator` that adds up any numeric type returned by the given function. */
-class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
-  val numeric = implicitly[Numeric[N]]
-
-  override def zero: N = numeric.zero
-
-  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-
-  override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
-
-  override def finish(reduction: N): N = reduction
-}
-
-object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] {
-  override def zero: (Long, Long) = (0, 0)
-
-  override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
-    (countAndSum._1 + 1, countAndSum._2 + input._2)
-  }
-
-  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
-    (b1._1 + b2._1, b1._2 + b2._2)
-  }
-
-  override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1
-}
 
 object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
@@ -113,15 +88,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
   import testImplicits._
 
-  def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
-    new SumOf(f).toColumn
-
   test("typed aggregation: TypedAggregator") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
 
     checkDataset(
-      ds.groupByKey(_._1).agg(sum(_._2)),
-      ("a", 30), ("b", 3), ("c", 1))
+      ds.groupByKey(_._1).agg(typed.sum(_._2)),
+      ("a", 30.0), ("b", 3.0), ("c", 1.0))
   }
 
   test("typed aggregation: TypedAggregator, expr, expr") {
@@ -129,20 +101,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
     checkDataset(
       ds.groupByKey(_._1).agg(
-        sum(_._2),
+        typed.sum(_._2),
         expr("sum(_2)").as[Long],
         count("*")),
-      ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L))
-  }
-
-  test("typed aggregation: complex case") {
-    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
-
-    checkDataset(
-      ds.groupByKey(_._1).agg(
-        expr("avg(_2)").as[Double],
-        TypedAverage.toColumn),
-      ("a", 2.0, 2.0), ("b", 3.0, 3.0))
+      ("a", 30.0, 30L, 2L), ("b", 3.0, 3L, 2L), ("c", 1.0, 1L, 1L))
   }
 
   test("typed aggregation: complex result type") {
@@ -159,11 +121,11 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
     val ds = Seq(1, 3, 2, 5).toDS()
 
     checkDataset(
-      ds.select(sum((i: Int) => i)),
-      11)
+      ds.select(typed.sum((i: Int) => i)),
+      11.0)
     checkDataset(
-      ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
-      11 -> 22)
+      ds.select(typed.sum((i: Int) => i), typed.sum((i: Int) => i * 2)),
+      11.0 -> 22.0)
   }
 
   test("typed aggregation: class input") {
@@ -206,4 +168,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn),
       ("one", 1), ("two", 1))
   }
+
+  test("typed aggregate: avg, count, sum") {
+    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+    checkDataset(
+      ds.groupByKey(_._1).agg(
+        typed.avg(_._2), typed.count(_._2), typed.sum(_._2), typed.sumLong(_._2)),
+      ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L))
+  }
 }
```
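For readers unfamiliar with `Aggregator`, the contract the new classes implement is a plain fold/combine scheme. A standalone sketch (pure Scala, no Spark required; the object and type names here are illustrative, not part of the patch) of the same (sum, count) buffer logic that `TypedAverage` uses:

```scala
// Mirrors TypedAverage's buffer: zero/reduce fold inputs within a partition,
// merge combines partial buffers across partitions, and finish maps the
// final buffer to the result.
object AverageContractDemo {
  type Buf = (Double, Long) // (running sum, running count)

  def zero: Buf = (0.0, 0L)
  def reduce(b: Buf, a: Double): Buf = (b._1 + a, b._2 + 1)
  def merge(b1: Buf, b2: Buf): Buf = (b1._1 + b2._1, b1._2 + b2._2)
  def finish(b: Buf): Double = b._1 / b._2

  def main(args: Array[String]): Unit = {
    // Simulate two partitions reduced independently, then merged.
    val p1 = Seq(1.0, 3.0).foldLeft(zero)(reduce)
    val p2 = Seq(2.0).foldLeft(zero)(reduce)
    println(finish(merge(p1, p2))) // prints 2.0
  }
}
```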