diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-04-16 00:31:51 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-04-16 00:31:51 -0700 |
commit | 12854464c4fa30c4df3b5b17bd8914d048dbf4a9 (patch) | |
tree | 84db7105342804ed5580c287db25cbd639bf407b /sql/core | |
parent | f4be0946af219379fb2476e6f80b2e50463adeb2 (diff) | |
download | spark-12854464c4fa30c4df3b5b17bd8914d048dbf4a9.tar.gz spark-12854464c4fa30c4df3b5b17bd8914d048dbf4a9.tar.bz2 spark-12854464c4fa30c4df3b5b17bd8914d048dbf4a9.zip |
[SPARK-13363][SQL] support Aggregator in RelationalGroupedDataset
## What changes were proposed in this pull request?
Set the input encoder for `TypedColumn` in `RelationalGroupedDataset.agg`, so that a typed `Aggregator` column can be used when aggregating a DataFrame (`Dataset[Row]`).
## How was this patch tested?
New tests were added in `DatasetAggregatorSuite`.
close https://github.com/apache/spark/pull/11269
Author: Wenchen Fan <wenchen@databricks.com>
Closes #12359 from cloud-fan/agg.
Diffstat (limited to 'sql/core')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala | 6 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala | 14 |
2 files changed, 18 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7dbf2e6c7c..deb2e82165 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -208,7 +208,11 @@ class RelationalGroupedDataset protected[sql]( */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map(_.expr)) + toDF((expr +: exprs).map { + case typed: TypedColumn[_, _] => + typed.withInputType(df.resolvedTEncoder, df.logicalPlan.output).expr + case c => c.expr + }) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 3a7215ee39..0d84a594f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import scala.language.postfixOps -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions._ @@ -85,6 +84,15 @@ class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT) override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]] } +object RowAgg extends Aggregator[Row, Int, Int] { + def zero: Int = 0 + def reduce(b: Int, a: Row): Int = a.getInt(0) + b + def merge(b1: Int, b2: Int): Int = b1 + b2 + def finish(r: Int): Int = r + override def bufferEncoder: Encoder[Int] = Encoders.scalaInt + override def outputEncoder: Encoder[Int] = Encoders.scalaInt +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { @@ -200,4 +208,8 @@ class DatasetAggregatorSuite extends QueryTest with 
SharedSQLContext { (1279869254, "Some String")) } + test("aggregator in DataFrame/Dataset[Row]") { + val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") + checkAnswer(df.groupBy($"j").agg(RowAgg.toColumn), Row("a", 1) :: Row("b", 5) :: Nil) + } } |