diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-04-18 14:27:26 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-04-18 14:27:26 +0800 |
commit | 2f1d0320c97f064556fa1cf98d4e30d2ab2fe661 (patch) | |
tree | 4cf4964d6de34d2c32100a42206dea9fa799f2d6 /sql/core/src | |
parent | 7de06a646dff7ede520d2e982ac0996d8c184650 (diff) | |
download | spark-2f1d0320c97f064556fa1cf98d4e30d2ab2fe661.tar.gz spark-2f1d0320c97f064556fa1cf98d4e30d2ab2fe661.tar.bz2 spark-2f1d0320c97f064556fa1cf98d4e30d2ab2fe661.zip |
[SPARK-13363][SQL] support Aggregator in RelationalGroupedDataset
## What changes were proposed in this pull request?
Set the input encoder for `TypedColumn` in `RelationalGroupedDataset.agg`.
## How was this patch tested?
new tests in `DatasetAggregatorSuite`
close https://github.com/apache/spark/pull/11269
This PR brings https://github.com/apache/spark/pull/12359 up to date and fixes the compilation.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #12451 from cloud-fan/agg.
Diffstat (limited to 'sql/core/src')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala | 6 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala | 14 |
2 files changed, 18 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7dbf2e6c7c..0ffb136c24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -208,7 +208,11 @@ class RelationalGroupedDataset protected[sql]( */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map(_.expr)) + toDF((expr +: exprs).map { + case typed: TypedColumn[_, _] => + typed.withInputType(df.unresolvedTEncoder.deserializer, df.logicalPlan.output).expr + case c => c.expr + }) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 3a7215ee39..0d84a594f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import scala.language.postfixOps -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions._ @@ -85,6 +84,15 @@ class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT) override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]] } +object RowAgg extends Aggregator[Row, Int, Int] { + def zero: Int = 0 + def reduce(b: Int, a: Row): Int = a.getInt(0) + b + def merge(b1: Int, b2: Int): Int = b1 + b2 + def finish(r: Int): Int = r + override def bufferEncoder: Encoder[Int] = Encoders.scalaInt + override def outputEncoder: Encoder[Int] = Encoders.scalaInt +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { @@ -200,4 +208,8 @@ class DatasetAggregatorSuite extends 
QueryTest with SharedSQLContext { (1279869254, "Some String")) } + test("aggregator in DataFrame/Dataset[Row]") { + val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") + checkAnswer(df.groupBy($"j").agg(RowAgg.toColumn), Row("a", 1) :: Row("b", 5) :: Nil) + } } |