From 2f1d0320c97f064556fa1cf98d4e30d2ab2fe661 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 18 Apr 2016 14:27:26 +0800 Subject: [SPARK-13363][SQL] support Aggregator in RelationalGroupedDataset ## What changes were proposed in this pull request? set the input encoder for `TypedColumn` in `RelationalGroupedDataset.agg`. ## How was this patch tested? new tests in `DatasetAggregatorSuite` close https://github.com/apache/spark/pull/11269 This PR brings https://github.com/apache/spark/pull/12359 up to date and fix the compile. Author: Wenchen Fan Closes #12451 from cloud-fan/agg. --- .../org/apache/spark/sql/RelationalGroupedDataset.scala | 6 +++++- .../org/apache/spark/sql/DatasetAggregatorSuite.scala | 14 +++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7dbf2e6c7c..0ffb136c24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -208,7 +208,11 @@ class RelationalGroupedDataset protected[sql]( */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map(_.expr)) + toDF((expr +: exprs).map { + case typed: TypedColumn[_, _] => + typed.withInputType(df.unresolvedTEncoder.deserializer, df.logicalPlan.output).expr + case c => c.expr + }) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 3a7215ee39..0d84a594f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import scala.language.postfixOps -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions._ @@ -85,6 +84,15 @@ class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT) override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]] } +object RowAgg extends Aggregator[Row, Int, Int] { + def zero: Int = 0 + def reduce(b: Int, a: Row): Int = a.getInt(0) + b + def merge(b1: Int, b2: Int): Int = b1 + b2 + def finish(r: Int): Int = r + override def bufferEncoder: Encoder[Int] = Encoders.scalaInt + override def outputEncoder: Encoder[Int] = Encoders.scalaInt +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { @@ -200,4 +208,8 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { (1279869254, "Some String")) } + test("aggregator in DataFrame/Dataset[Row]") { + val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") + checkAnswer(df.groupBy($"j").agg(RowAgg.toColumn), Row("a", 1) :: Row("b", 5) :: Nil) + } } -- cgit v1.2.3