author    Reynold Xin <rxin@databricks.com>    2016-04-01 22:46:56 -0700
committer Reynold Xin <rxin@databricks.com>    2016-04-01 22:46:56 -0700
commit    f414154418c2291448954b9f0890d592b2d823ae (patch)
tree      1663d938faacb33b1607e4beb0e9ec5afdf3f493 /sql/core/src/test/scala
parent    fa1af0aff7bde9bbf7bfa6a3ac74699734c2fd8a (diff)
[SPARK-14285][SQL] Implement common type-safe aggregate functions
## What changes were proposed in this pull request?

In the Dataset API it is currently difficult for users to perform even simple aggregations in a type-safe way, because no aggregators have been implemented. This pull request adds a few common aggregate functions in the expressions.scala.typed package, and also creates the expressions.java.typed package without an implementation. The Java implementation should probably come as a separate pull request; one challenge there is resolving the type difference between Scala primitive types and Java boxed types.

## How was this patch tested?

Added unit tests for them.

Author: Reynold Xin <rxin@databricks.com>

Closes #12077 from rxin/SPARK-14285.
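For context, a minimal sketch of what the new API enables, with names taken from the test changes in the diff below. It assumes the usual implicits are in scope (e.g. `import testImplicits._` as in the suite); each `typed.*` call returns a `TypedColumn`, so the aggregation stays type-safe end to end:

```scala
import org.apache.spark.sql.expressions.scala.typed

val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

// Yields a Dataset[(String, Double, Long, Double, Long)], matching
// the new "typed aggregate: avg, count, sum" test below.
ds.groupByKey(_._1).agg(
  typed.avg(_._2),      // average as Double
  typed.count(_._2),    // element count as Long
  typed.sum(_._2),      // sum widened to Double
  typed.sumLong(_._2))  // sum kept as Long
```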
Diffstat (limited to 'sql/core/src/test/scala')
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala | 64
1 file changed, 17 insertions(+), 47 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 84770169f0..5430aff6ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -20,35 +20,10 @@ package org.apache.spark.sql
import scala.language.postfixOps
import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.expressions.scala.typed
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
-/** An `Aggregator` that adds up any numeric type returned by the given function. */
-class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
- val numeric = implicitly[Numeric[N]]
-
- override def zero: N = numeric.zero
-
- override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-
- override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
-
- override def finish(reduction: N): N = reduction
-}
-
-object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] {
- override def zero: (Long, Long) = (0, 0)
-
- override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
- (countAndSum._1 + 1, countAndSum._2 + input._2)
- }
-
- override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
- (b1._1 + b2._1, b1._2 + b2._2)
- }
-
- override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1
-}
object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
@@ -113,15 +88,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
- def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
- new SumOf(f).toColumn
-
test("typed aggregation: TypedAggregator") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
checkDataset(
- ds.groupByKey(_._1).agg(sum(_._2)),
- ("a", 30), ("b", 3), ("c", 1))
+ ds.groupByKey(_._1).agg(typed.sum(_._2)),
+ ("a", 30.0), ("b", 3.0), ("c", 1.0))
}
test("typed aggregation: TypedAggregator, expr, expr") {
@@ -129,20 +101,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
checkDataset(
ds.groupByKey(_._1).agg(
- sum(_._2),
+ typed.sum(_._2),
expr("sum(_2)").as[Long],
count("*")),
- ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L))
- }
-
- test("typed aggregation: complex case") {
- val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
-
- checkDataset(
- ds.groupByKey(_._1).agg(
- expr("avg(_2)").as[Double],
- TypedAverage.toColumn),
- ("a", 2.0, 2.0), ("b", 3.0, 3.0))
+ ("a", 30.0, 30L, 2L), ("b", 3.0, 3L, 2L), ("c", 1.0, 1L, 1L))
}
test("typed aggregation: complex result type") {
@@ -159,11 +121,11 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
val ds = Seq(1, 3, 2, 5).toDS()
checkDataset(
- ds.select(sum((i: Int) => i)),
- 11)
+ ds.select(typed.sum((i: Int) => i)),
+ 11.0)
checkDataset(
- ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
- 11 -> 22)
+ ds.select(typed.sum((i: Int) => i), typed.sum((i: Int) => i * 2)),
+ 11.0 -> 22.0)
}
test("typed aggregation: class input") {
@@ -206,4 +168,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn),
("one", 1), ("two", 1))
}
+
+ test("typed aggregate: avg, count, sum") {
+ val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+ checkDataset(
+ ds.groupByKey(_._1).agg(
+ typed.avg(_._2), typed.count(_._2), typed.sum(_._2), typed.sumLong(_._2)),
+ ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L))
+ }
}
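As an aside, the `SumOf` and `TypedAverage` code removed above illustrates the generic `Aggregator` contract (`zero`, `reduce`, `merge`, `finish`) that the new `typed` functions wrap. A minimal sketch of the same pattern as it stands in this snapshot of the API (the `AvgOf` name is illustrative, not part of the patch):

```scala
import org.apache.spark.sql.expressions.Aggregator

// Illustrative: average via a (count, sum) buffer, mirroring the
// TypedAverage object removed by this patch.
object AvgOf extends Aggregator[(String, Int), (Long, Long), Double] {
  override def zero: (Long, Long) = (0L, 0L)

  override def reduce(b: (Long, Long), a: (String, Int)): (Long, Long) =
    (b._1 + 1, b._2 + a._2)

  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
    (b1._1 + b2._1, b1._2 + b2._2)

  // Divide as Double: the removed TypedAverage used Long division
  // here, which truncates before the result is widened to Double.
  override def finish(r: (Long, Long)): Double = r._2.toDouble / r._1
}

// Used in a query via toColumn, e.g.:
//   ds.groupByKey(_._1).agg(AvgOf.toColumn)
```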