diff options
author | petermaxlee <petermaxlee@gmail.com> | 2016-08-21 00:25:55 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-08-21 00:25:55 +0800 |
commit | 9560c8d29542a5dcaaa07b7af9ef5ddcdbb5d14d (patch) | |
tree | 388d35f3f0833bbd2a653ad22f185569b55ed12b /sql | |
parent | 31a015572024046f4deaa6cec66bb6fab110f31d (diff) | |
download | spark-9560c8d29542a5dcaaa07b7af9ef5ddcdbb5d14d.tar.gz spark-9560c8d29542a5dcaaa07b7af9ef5ddcdbb5d14d.tar.bz2 spark-9560c8d29542a5dcaaa07b7af9ef5ddcdbb5d14d.zip |
[SPARK-17124][SQL] RelationalGroupedDataset.agg should preserve order and allow multiple aggregates per column
## What changes were proposed in this pull request?
This patch fixes a longstanding issue with one of the RelationalGroupedDataset.agg function. Even though the signature accepts vararg of pairs, the underlying implementation turns the seq into a map, and thus not order preserving nor allowing multiple aggregates per column.
This change also allows users to use this function to run multiple different aggregations for a single column, e.g.
```
agg("age" -> "max", "age" -> "count")
```
## How was this patch tested?
Added a test case in DataFrameAggregateSuite.
Author: petermaxlee <petermaxlee@gmail.com>
Closes #14697 from petermaxlee/SPARK-17124.
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala | 6 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 10 |
2 files changed, 14 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7cfd1cdc7d..53d732403f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -128,7 +128,7 @@ class RelationalGroupedDataset protected[sql]( } /** - * (Scala-specific) Compute aggregates by specifying a map from column name to + * (Scala-specific) Compute aggregates by specifying the column names and * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. * * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. @@ -143,7 +143,9 @@ class RelationalGroupedDataset protected[sql]( * @since 1.3.0 */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - agg((aggExpr +: aggExprs).toMap) + toDF((aggExpr +: aggExprs).map { case (colName, expr) => + strToExpr(expr)(df(colName).expr) + }) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 92aa7b9543..69a3b5f278 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -87,6 +87,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-17124 agg should be ordering preserving") { + val df = spark.range(2) + val ret = df.groupBy("id").agg("id" -> "sum", "id" -> "count", "id" -> "min") + assert(ret.schema.map(_.name) == Seq("id", "sum(id)", "count(id)", "min(id)")) + checkAnswer( + ret, + Row(0, 0, 1, 0) :: Row(1, 1, 1, 1) :: Nil + ) + } + test("rollup") { checkAnswer( courseSales.rollup("course", "year").sum("earnings"), |