aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-04-18 14:27:26 +0800
committerWenchen Fan <wenchen@databricks.com>2016-04-18 14:27:26 +0800
commit2f1d0320c97f064556fa1cf98d4e30d2ab2fe661 (patch)
tree4cf4964d6de34d2c32100a42206dea9fa799f2d6
parent7de06a646dff7ede520d2e982ac0996d8c184650 (diff)
downloadspark-2f1d0320c97f064556fa1cf98d4e30d2ab2fe661.tar.gz
spark-2f1d0320c97f064556fa1cf98d4e30d2ab2fe661.tar.bz2
spark-2f1d0320c97f064556fa1cf98d4e30d2ab2fe661.zip
[SPARK-13363][SQL] support Aggregator in RelationalGroupedDataset
## What changes were proposed in this pull request? set the input encoder for `TypedColumn` in `RelationalGroupedDataset.agg`. ## How was this patch tested? new tests in `DatasetAggregatorSuite` close https://github.com/apache/spark/pull/11269 This PR brings https://github.com/apache/spark/pull/12359 up to date and fix the compile. Author: Wenchen Fan <wenchen@databricks.com> Closes #12451 from cloud-fan/agg.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala14
2 files changed, 18 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 7dbf2e6c7c..0ffb136c24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -208,7 +208,11 @@ class RelationalGroupedDataset protected[sql](
*/
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame = {
- toDF((expr +: exprs).map(_.expr))
+ toDF((expr +: exprs).map {
+ case typed: TypedColumn[_, _] =>
+ typed.withInputType(df.unresolvedTEncoder.deserializer, df.logicalPlan.output).expr
+ case c => c.expr
+ })
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 3a7215ee39..0d84a594f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql
import scala.language.postfixOps
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.scala.typed
import org.apache.spark.sql.functions._
@@ -85,6 +84,15 @@ class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT)
override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]]
}
+object RowAgg extends Aggregator[Row, Int, Int] {
+ def zero: Int = 0
+ def reduce(b: Int, a: Row): Int = a.getInt(0) + b
+ def merge(b1: Int, b2: Int): Int = b1 + b2
+ def finish(r: Int): Int = r
+ override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+ override def outputEncoder: Encoder[Int] = Encoders.scalaInt
+}
+
class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
@@ -200,4 +208,8 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
(1279869254, "Some String"))
}
+ test("aggregator in DataFrame/Dataset[Row]") {
+ val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
+ checkAnswer(df.groupBy($"j").agg(RowAgg.toColumn), Row("a", 1) :: Row("b", 5) :: Nil)
+ }
}