| author | Reynold Xin <rxin@databricks.com> | 2016-04-01 22:46:56 -0700 |
|---|---|---|
| committer | Reynold Xin <rxin@databricks.com> | 2016-04-01 22:46:56 -0700 |
| commit | f414154418c2291448954b9f0890d592b2d823ae (patch) | |
| tree | 1663d938faacb33b1607e4beb0e9ec5afdf3f493 /sql/core/src/test/scala | |
| parent | fa1af0aff7bde9bbf7bfa6a3ac74699734c2fd8a (diff) | |
[SPARK-14285][SQL] Implement common type-safe aggregate functions
## What changes were proposed in this pull request?
In the Dataset API it is currently difficult for users to perform simple aggregations in a type-safe way, because no common aggregators have been implemented yet. This pull request adds a few common aggregate functions to the `expressions.scala.typed` package, and also creates the `expressions.java.typed` package without an implementation. The Java implementation should come as a separate pull request; one challenge there is resolving the type difference between Scala primitive types and Java boxed types.
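To illustrate the new API, here is a minimal usage sketch mirroring the test added below; it assumes a Spark (test) session with the Dataset implicits (e.g. the suite's `testImplicits._`) in scope:

```scala
import org.apache.spark.sql.expressions.scala.typed

// Assumes Dataset implicits are imported, as in the test suite.
val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()

// Each typed aggregator produces a statically typed column:
// avg and sum return Double; count and sumLong return Long.
val agged = ds.groupByKey(_._1).agg(
  typed.avg(_._2),
  typed.count(_._2),
  typed.sum(_._2),
  typed.sumLong(_._2))

// agged: Dataset[(String, Double, Long, Double, Long)]
// contents: ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L)
```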
## How was this patch tested?
Added unit tests for them.
Author: Reynold Xin <rxin@databricks.com>
Closes #12077 from rxin/SPARK-14285.
Diffstat (limited to 'sql/core/src/test/scala')
| -rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala | 64 |

1 file changed, 17 insertions(+), 47 deletions(-)
```diff
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 84770169f0..5430aff6ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -20,35 +20,10 @@ package org.apache.spark.sql
 import scala.language.postfixOps
 
 import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.expressions.scala.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 
-/** An `Aggregator` that adds up any numeric type returned by the given function. */
-class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
-  val numeric = implicitly[Numeric[N]]
-
-  override def zero: N = numeric.zero
-
-  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-
-  override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
-
-  override def finish(reduction: N): N = reduction
-}
-
-object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] {
-  override def zero: (Long, Long) = (0, 0)
-
-  override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
-    (countAndSum._1 + 1, countAndSum._2 + input._2)
-  }
-
-  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
-    (b1._1 + b2._1, b1._2 + b2._2)
-  }
-
-  override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1
-}
 
 object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
@@ -113,15 +88,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
   import testImplicits._
 
-  def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
-    new SumOf(f).toColumn
-
   test("typed aggregation: TypedAggregator") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
 
     checkDataset(
-      ds.groupByKey(_._1).agg(sum(_._2)),
-      ("a", 30), ("b", 3), ("c", 1))
+      ds.groupByKey(_._1).agg(typed.sum(_._2)),
+      ("a", 30.0), ("b", 3.0), ("c", 1.0))
   }
 
   test("typed aggregation: TypedAggregator, expr, expr") {
@@ -129,20 +101,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
     checkDataset(
       ds.groupByKey(_._1).agg(
-        sum(_._2),
+        typed.sum(_._2),
         expr("sum(_2)").as[Long],
         count("*")),
-      ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L))
-  }
-
-  test("typed aggregation: complex case") {
-    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
-
-    checkDataset(
-      ds.groupByKey(_._1).agg(
-        expr("avg(_2)").as[Double],
-        TypedAverage.toColumn),
-      ("a", 2.0, 2.0), ("b", 3.0, 3.0))
+      ("a", 30.0, 30L, 2L), ("b", 3.0, 3L, 2L), ("c", 1.0, 1L, 1L))
   }
 
   test("typed aggregation: complex result type") {
@@ -159,11 +121,11 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
     val ds = Seq(1, 3, 2, 5).toDS()
 
     checkDataset(
-      ds.select(sum((i: Int) => i)),
-      11)
+      ds.select(typed.sum((i: Int) => i)),
+      11.0)
     checkDataset(
-      ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
-      11 -> 22)
+      ds.select(typed.sum((i: Int) => i), typed.sum((i: Int) => i * 2)),
+      11.0 -> 22.0)
   }
 
   test("typed aggregation: class input") {
@@ -206,4 +168,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn),
       ("one", 1), ("two", 1))
   }
+
+  test("typed aggregate: avg, count, sum") {
+    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+    checkDataset(
+      ds.groupByKey(_._1).agg(
+        typed.avg(_._2), typed.count(_._2), typed.sum(_._2), typed.sumLong(_._2)),
+      ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L))
+  }
 }
```
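As context for the changed expectations above (e.g. `("a", 30.0)` instead of `("a", 30)`): the removed ad-hoc `SumOf` aggregator preserved the input's numeric type, while `typed.sum` sums as `Double` (with `typed.sumLong` covering the `Long` case). Below is a minimal sketch of what such an aggregator might look like, following the `Aggregator` pattern of the removed code; `SumDouble` and `sumDouble` are hypothetical names for illustration, not the actual implementation:

```scala
import org.apache.spark.sql.{Encoder, TypedColumn}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical re-implementation for illustration: sum a Double-valued
// projection of the input, so results surface as Double (30.0, 3.0, ...).
class SumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
  override def zero: Double = 0.0
  override def reduce(b: Double, a: IN): Double = b + f(a)
  override def merge(b1: Double, b2: Double): Double = b1 + b2
  override def finish(reduction: Double): Double = reduction
}

// Expose it as a TypedColumn via toColumn, as the suite does with
// ComplexBufferAgg; an implicit Encoder[Double] (e.g. from the suite's
// testImplicits) satisfies toColumn's encoder requirements.
def sumDouble[IN](f: IN => Double)(implicit enc: Encoder[Double]): TypedColumn[IN, Double] =
  new SumDouble(f).toColumn
```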