aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorpetermaxlee <petermaxlee@gmail.com>2016-08-21 00:25:55 +0800
committerWenchen Fan <wenchen@databricks.com>2016-08-21 00:25:55 +0800
commit9560c8d29542a5dcaaa07b7af9ef5ddcdbb5d14d (patch)
tree388d35f3f0833bbd2a653ad22f185569b55ed12b
parent31a015572024046f4deaa6cec66bb6fab110f31d (diff)
downloadspark-9560c8d29542a5dcaaa07b7af9ef5ddcdbb5d14d.tar.gz
spark-9560c8d29542a5dcaaa07b7af9ef5ddcdbb5d14d.tar.bz2
spark-9560c8d29542a5dcaaa07b7af9ef5ddcdbb5d14d.zip
[SPARK-17124][SQL] RelationalGroupedDataset.agg should preserve order and allow multiple aggregates per column
## What changes were proposed in this pull request? This patch fixes a longstanding issue with one of the RelationalGroupedDataset.agg functions. Even though the signature accepts a vararg of pairs, the underlying implementation turns the seq into a map, and thus is neither order-preserving nor able to support multiple aggregates per column. This change also allows users to use this function to run multiple different aggregations for a single column, e.g. ``` agg("age" -> "max", "age" -> "count") ``` ## How was this patch tested? Added a test case in DataFrameAggregateSuite. Author: petermaxlee <petermaxlee@gmail.com> Closes #14697 from petermaxlee/SPARK-17124.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala10
2 files changed, 14 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 7cfd1cdc7d..53d732403f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -128,7 +128,7 @@ class RelationalGroupedDataset protected[sql](
}
/**
- * (Scala-specific) Compute aggregates by specifying a map from column name to
+ * (Scala-specific) Compute aggregates by specifying the column names and
* aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns.
*
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
@@ -143,7 +143,9 @@ class RelationalGroupedDataset protected[sql](
* @since 1.3.0
*/
def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
- agg((aggExpr +: aggExprs).toMap)
+ toDF((aggExpr +: aggExprs).map { case (colName, expr) =>
+ strToExpr(expr)(df(colName).expr)
+ })
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 92aa7b9543..69a3b5f278 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -87,6 +87,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
+ test("SPARK-17124 agg should be ordering preserving") {
+ val df = spark.range(2)
+ val ret = df.groupBy("id").agg("id" -> "sum", "id" -> "count", "id" -> "min")
+ assert(ret.schema.map(_.name) == Seq("id", "sum(id)", "count(id)", "min(id)"))
+ checkAnswer(
+ ret,
+ Row(0, 0, 1, 0) :: Row(1, 1, 1, 1) :: Nil
+ )
+ }
+
test("rollup") {
checkAnswer(
courseSales.rollup("course", "year").sum("earnings"),