author     Reynold Xin <rxin@databricks.com>   2016-04-01 22:46:56 -0700
committer  Reynold Xin <rxin@databricks.com>   2016-04-01 22:46:56 -0700
commit     f414154418c2291448954b9f0890d592b2d823ae (patch)
tree       1663d938faacb33b1607e4beb0e9ec5afdf3f493
parent     fa1af0aff7bde9bbf7bfa6a3ac74699734c2fd8a (diff)
[SPARK-14285][SQL] Implement common type-safe aggregate functions
## What changes were proposed in this pull request?

In the Dataset API, it is currently difficult for users to perform simple aggregations in a type-safe way because no aggregators have been implemented. This pull request adds a few common aggregate functions in the expressions.scala.typed package, and also creates the expressions.java.typed package without an implementation. The Java implementation should probably come as a separate pull request; one challenge there is resolving the type difference between Scala primitive types and Java boxed types.

## How was this patch tested?

Added unit tests for them.

Author: Reynold Xin <rxin@databricks.com>

Closes #12077 from rxin/SPARK-14285.
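For a sense of the end result, the new Scala API is exercised by the updated DatasetAggregatorSuite at the bottom of this diff; a minimal usage sketch (assuming the usual SQL implicits are in scope so `toDS()` and the tuple encoders resolve) looks like this:

```scala
import org.apache.spark.sql.expressions.scala.typed

// The data here mirrors the new test case added below.
val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()

// Each aggregator takes a typed lambda rather than an untyped Column expression.
val agged = ds.groupByKey(_._1).agg(
  typed.avg(_._2),     // Double
  typed.count(_._2),   // Long
  typed.sum(_._2),     // Double
  typed.sumLong(_._2)) // Long

// Expected rows, as asserted in the new test:
// ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L)
```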
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala    69
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala                        8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala                    8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java                    34
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala                  89
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala                          4
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                      54
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java   123
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala                   64
9 files changed, 342 insertions, 111 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
new file mode 100644
index 0000000000..9afc29038b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.expressions.Aggregator
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// This file defines internal implementations for aggregators.
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] {
+ val numeric = implicitly[Numeric[OUT]]
+ override def zero: OUT = numeric.zero
+ override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
+ override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
+ override def finish(reduction: OUT): OUT = reduction
+}
+
+
+class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
+ override def zero: Double = 0.0
+ override def reduce(b: Double, a: IN): Double = b + f(a)
+ override def merge(b1: Double, b2: Double): Double = b1 + b2
+ override def finish(reduction: Double): Double = reduction
+}
+
+
+class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
+ override def zero: Long = 0L
+ override def reduce(b: Long, a: IN): Long = b + f(a)
+ override def merge(b1: Long, b2: Long): Long = b1 + b2
+ override def finish(reduction: Long): Long = reduction
+}
+
+
+class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
+ override def zero: Long = 0
+ override def reduce(b: Long, a: IN): Long = {
+ if (f(a) == null) b else b + 1
+ }
+ override def merge(b1: Long, b2: Long): Long = b1 + b2
+ override def finish(reduction: Long): Long = reduction
+}
+
+
+class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
+ override def zero: (Double, Long) = (0.0, 0L)
+ override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2)
+ override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
+ override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = {
+ (b1._1 + b2._1, b1._2 + b2._2)
+ }
+}
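As a quick illustration of the buffer contract these classes implement, here is a small sketch (plain Scala, no Spark runtime involved, with made-up input values) that drives TypedAverage through zero/reduce/merge/finish the way two partitions would:

```scala
import org.apache.spark.sql.execution.aggregate.TypedAverage

// Illustrative only: exercises the Aggregator contract directly.
val avg = new TypedAverage[Int](_.toDouble)

// Partition 1 sees 1 and 2, partition 2 sees 3; each folds into a (sum, count) buffer.
val p1 = Seq(1, 2).foldLeft(avg.zero)(avg.reduce) // (3.0, 2L)
val p2 = Seq(3).foldLeft(avg.zero)(avg.reduce)    // (3.0, 1L)

// Buffers are combined across partitions, then finish produces the average.
val result = avg.finish(avg.merge(p1, p2))        // 6.0 / 3 = 2.0
```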
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
index e9b60841fc..350c283646 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
@@ -42,7 +42,7 @@ object Window {
* Creates a [[WindowSpec]] with the partitioning defined.
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def partitionBy(colName: String, colNames: String*): WindowSpec = {
spec.partitionBy(colName, colNames : _*)
}
@@ -51,7 +51,7 @@ object Window {
* Creates a [[WindowSpec]] with the partitioning defined.
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def partitionBy(cols: Column*): WindowSpec = {
spec.partitionBy(cols : _*)
}
@@ -60,7 +60,7 @@ object Window {
* Creates a [[WindowSpec]] with the ordering defined.
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def orderBy(colName: String, colNames: String*): WindowSpec = {
spec.orderBy(colName, colNames : _*)
}
@@ -69,7 +69,7 @@ object Window {
* Creates a [[WindowSpec]] with the ordering defined.
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def orderBy(cols: Column*): WindowSpec = {
spec.orderBy(cols : _*)
}
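The `@scala.annotation.varargs` to `@_root_.scala.annotation.varargs` churn in this file, WindowSpec.scala, and udaf.scala appears to be a side effect of the new `org.apache.spark.sql.expressions.scala` package introduced by this patch: inside `org.apache.spark.sql.expressions`, a bare `scala.` now resolves to that sibling package rather than the root `scala` package. A minimal sketch of the shadowing (illustrative only, not part of the patch):

```scala
package org.apache.spark.sql.expressions

// With org.apache.spark.sql.expressions.scala present, a bare
// `scala.annotation.varargs` in this package would be looked up against the
// sibling `expressions.scala` package and fail to resolve; `_root_` forces
// resolution from the top-level scala package.
object VarargsShadowingSketch {
  @_root_.scala.annotation.varargs
  def join(sep: String, parts: String*): String = parts.mkString(sep)
}
```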
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index 9e9c58cb66..d716da2668 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -39,7 +39,7 @@ class WindowSpec private[sql](
* Defines the partitioning columns in a [[WindowSpec]].
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def partitionBy(colName: String, colNames: String*): WindowSpec = {
partitionBy((colName +: colNames).map(Column(_)): _*)
}
@@ -48,7 +48,7 @@ class WindowSpec private[sql](
* Defines the partitioning columns in a [[WindowSpec]].
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def partitionBy(cols: Column*): WindowSpec = {
new WindowSpec(cols.map(_.expr), orderSpec, frame)
}
@@ -57,7 +57,7 @@ class WindowSpec private[sql](
* Defines the ordering columns in a [[WindowSpec]].
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def orderBy(colName: String, colNames: String*): WindowSpec = {
orderBy((colName +: colNames).map(Column(_)): _*)
}
@@ -66,7 +66,7 @@ class WindowSpec private[sql](
* Defines the ordering columns in a [[WindowSpec]].
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def orderBy(cols: Column*): WindowSpec = {
val sortOrder: Seq[SortOrder] = cols.map { col =>
col.expr match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
new file mode 100644
index 0000000000..cdba970d8f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.java;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.Dataset;
+
+/**
+ * :: Experimental ::
+ * Type-safe functions available for {@link Dataset} operations in Java.
+ *
+ * Scala users should use {@link org.apache.spark.sql.expressions.scala.typed}.
+ *
+ * @since 2.0.0
+ */
+@Experimental
+public class typed {
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala
new file mode 100644
index 0000000000..d0eb190afd
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.scala
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.aggregate._
+
+/**
+ * :: Experimental ::
+ * Type-safe functions available for [[Dataset]] operations in Scala.
+ *
+ * Java users should use [[org.apache.spark.sql.expressions.java.typed]].
+ *
+ * @since 2.0.0
+ */
+@Experimental
+// scalastyle:off
+object typed {
+ // scalastyle:on
+
+ // Note: whenever we update this file, we should update the corresponding Java version too.
+ // The reason we have separate files for Java and Scala is because in the Scala version, we can
+ // use tighter types (primitive types) for return types, whereas in the Java version we can only
+ // use boxed primitive types.
+ // For example, avg in the Scala version returns Scala primitive Double, whose bytecode
+ // signature is just a java.lang.Object; avg in the Java version returns java.lang.Double.
+
+ // TODO: This is pretty hacky. Maybe we should have an object for implicit encoders.
+ private val implicits = new SQLImplicits {
+ override protected def _sqlContext: SQLContext = null
+ }
+
+ import implicits._
+
+ /**
+ * Average aggregate function.
+ *
+ * @since 2.0.0
+ */
+ def avg[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedAverage(f).toColumn
+
+ /**
+ * Count aggregate function.
+ *
+ * @since 2.0.0
+ */
+ def count[IN](f: IN => Any): TypedColumn[IN, Long] = new TypedCount(f).toColumn
+
+ /**
+ * Sum aggregate function for floating point (double) type.
+ *
+ * @since 2.0.0
+ */
+ def sum[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedSumDouble[IN](f).toColumn
+
+ /**
+ * Sum aggregate function for integral (long, i.e. 64 bit integer) type.
+ *
+ * @since 2.0.0
+ */
+ def sumLong[IN](f: IN => Long): TypedColumn[IN, Long] = new TypedSumLong[IN](f).toColumn
+
+ // TODO:
+ // stddevOf: Double
+ // varianceOf: Double
+ // approxCountDistinct: Long
+
+ // minOf: T
+ // maxOf: T
+
+ // firstOf: T
+ // lastOf: T
+}
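Besides grouped aggregation, the returned TypedColumns also work with a plain select over the whole Dataset, which the updated DatasetAggregatorSuite below relies on; a minimal sketch assuming the usual implicits are in scope:

```scala
import org.apache.spark.sql.expressions.scala.typed

// Assumes the SQL implicits are imported so toDS() and encoders resolve.
val ds = Seq(1, 3, 2, 5).toDS()

// A single typed aggregate over the whole Dataset yields a Dataset[Double].
val total = ds.select(typed.sum((i: Int) => i.toDouble))      // 11.0

// Two typed aggregates yield a Dataset[(Double, Double)].
val pair = ds.select(typed.sum((i: Int) => i.toDouble),
                     typed.sum((i: Int) => i * 2.0))           // (11.0, 22.0)
```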
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index 8b355befc3..48925910ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -106,7 +106,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
/**
* Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def apply(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(
@@ -120,7 +120,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
* Creates a [[Column]] for this UDAF using the distinct values of the given
* [[Column]]s as input arguments.
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def distinct(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a6c819373b..a5ab446e08 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -37,7 +37,6 @@ import org.apache.spark.SparkContext;
import org.apache.spark.api.java.function.*;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
-import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.catalyst.encoders.OuterScopes;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
@@ -385,59 +384,6 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(data, ds.collectAsList());
}
- @Test
- public void testTypedAggregation() {
- Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
- List<Tuple2<String, Integer>> data =
- Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
- Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
-
- KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey(
- new MapFunction<Tuple2<String, Integer>, String>() {
- @Override
- public String call(Tuple2<String, Integer> value) throws Exception {
- return value._1();
- }
- },
- Encoders.STRING());
-
- Dataset<Tuple2<String, Integer>> agged =
- grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
- Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
-
- Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
- new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
- .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
- Assert.assertEquals(
- Arrays.asList(
- new Tuple2<>("a", 3),
- new Tuple2<>("b", 3)),
- agged2.collectAsList());
- }
-
- static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
-
- @Override
- public Integer zero() {
- return 0;
- }
-
- @Override
- public Integer reduce(Integer l, Tuple2<String, Integer> t) {
- return l + t._2();
- }
-
- @Override
- public Integer merge(Integer b1, Integer b2) {
- return b1 + b2;
- }
-
- @Override
- public Integer finish(Integer reduction) {
- return reduction;
- }
- }
-
public static class KryoSerializable {
String value;
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
new file mode 100644
index 0000000000..c4c455b6e6
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import scala.Tuple2;
+
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.sql.test.TestSQLContext;
+
+/**
+ * Suite for testing the aggregate functionality of Datasets in Java.
+ */
+public class JavaDatasetAggregatorSuite implements Serializable {
+ private transient JavaSparkContext jsc;
+ private transient TestSQLContext context;
+
+ @Before
+ public void setUp() {
+ // Trigger static initializer of TestData
+ SparkContext sc = new SparkContext("local[*]", "testing");
+ jsc = new JavaSparkContext(sc);
+ context = new TestSQLContext(sc);
+ context.loadTestData();
+ }
+
+ @After
+ public void tearDown() {
+ context.sparkContext().stop();
+ context = null;
+ jsc = null;
+ }
+
+ private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
+ return new Tuple2<>(t1, t2);
+ }
+
+ private KeyValueGroupedDataset<String, Tuple2<String, Integer>> generateGroupedDataset() {
+ Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
+ List<Tuple2<String, Integer>> data =
+ Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
+ Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
+
+ return ds.groupByKey(
+ new MapFunction<Tuple2<String, Integer>, String>() {
+ @Override
+ public String call(Tuple2<String, Integer> value) throws Exception {
+ return value._1();
+ }
+ },
+ Encoders.STRING());
+ }
+
+ @Test
+ public void testTypedAggregationAnonClass() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+
+ Dataset<Tuple2<String, Integer>> agged =
+ grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+
+ Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
+ new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
+ .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
+ Assert.assertEquals(
+ Arrays.asList(
+ new Tuple2<>("a", 3),
+ new Tuple2<>("b", 3)),
+ agged2.collectAsList());
+ }
+
+ static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
+
+ @Override
+ public Integer zero() {
+ return 0;
+ }
+
+ @Override
+ public Integer reduce(Integer l, Tuple2<String, Integer> t) {
+ return l + t._2();
+ }
+
+ @Override
+ public Integer merge(Integer b1, Integer b2) {
+ return b1 + b2;
+ }
+
+ @Override
+ public Integer finish(Integer reduction) {
+ return reduction;
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 84770169f0..5430aff6ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -20,35 +20,10 @@ package org.apache.spark.sql
import scala.language.postfixOps
import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.expressions.scala.typed
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
-/** An `Aggregator` that adds up any numeric type returned by the given function. */
-class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
- val numeric = implicitly[Numeric[N]]
-
- override def zero: N = numeric.zero
-
- override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-
- override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
-
- override def finish(reduction: N): N = reduction
-}
-
-object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] {
- override def zero: (Long, Long) = (0, 0)
-
- override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
- (countAndSum._1 + 1, countAndSum._2 + input._2)
- }
-
- override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
- (b1._1 + b2._1, b1._2 + b2._2)
- }
-
- override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1
-}
object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
@@ -113,15 +88,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
- def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
- new SumOf(f).toColumn
-
test("typed aggregation: TypedAggregator") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
checkDataset(
- ds.groupByKey(_._1).agg(sum(_._2)),
- ("a", 30), ("b", 3), ("c", 1))
+ ds.groupByKey(_._1).agg(typed.sum(_._2)),
+ ("a", 30.0), ("b", 3.0), ("c", 1.0))
}
test("typed aggregation: TypedAggregator, expr, expr") {
@@ -129,20 +101,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
checkDataset(
ds.groupByKey(_._1).agg(
- sum(_._2),
+ typed.sum(_._2),
expr("sum(_2)").as[Long],
count("*")),
- ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L))
- }
-
- test("typed aggregation: complex case") {
- val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
-
- checkDataset(
- ds.groupByKey(_._1).agg(
- expr("avg(_2)").as[Double],
- TypedAverage.toColumn),
- ("a", 2.0, 2.0), ("b", 3.0, 3.0))
+ ("a", 30.0, 30L, 2L), ("b", 3.0, 3L, 2L), ("c", 1.0, 1L, 1L))
}
test("typed aggregation: complex result type") {
@@ -159,11 +121,11 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
val ds = Seq(1, 3, 2, 5).toDS()
checkDataset(
- ds.select(sum((i: Int) => i)),
- 11)
+ ds.select(typed.sum((i: Int) => i)),
+ 11.0)
checkDataset(
- ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
- 11 -> 22)
+ ds.select(typed.sum((i: Int) => i), typed.sum((i: Int) => i * 2)),
+ 11.0 -> 22.0)
}
test("typed aggregation: class input") {
@@ -206,4 +168,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn),
("one", 1), ("two", 1))
}
+
+ test("typed aggregate: avg, count, sum") {
+ val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+ checkDataset(
+ ds.groupByKey(_._1).agg(
+ typed.avg(_._2), typed.count(_._2), typed.sum(_._2), typed.sumLong(_._2)),
+ ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L))
+ }
}