From d084a2de3271fd8b0d29ee67e1e214e8b9d96d6d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 4 Jan 2016 14:26:56 -0800 Subject: [SPARK-12541] [SQL] support cube/rollup as function This PR enable cube/rollup as function, so they can be used as this: ``` select a, b, sum(c) from t group by rollup(a, b) ``` Author: Davies Liu Closes #10522 from davies/rollup. --- .../scala/org/apache/spark/sql/GroupedData.scala | 6 ++--- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 29 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) (limited to 'sql/core') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 13341a88a6..2aa82f1496 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate} +import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Aggregate} import org.apache.spark.sql.types.NumericType @@ -58,10 +58,10 @@ class GroupedData protected[sql]( df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case GroupedData.RollupType => DataFrame( - df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg)) + df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) case GroupedData.CubeType => DataFrame( - df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) + df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) case GroupedData.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) DataFrame( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bb82b562aa..115b617c21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2028,4 +2028,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(false) :: Row(true) :: Nil) } + test("rollup") { + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by rollup(course, year)" + + " order by course, year"), + Row(null, null, 113000.0) :: + Row("Java", null, 50000.0) :: + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("dotNET", null, 63000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: Nil + ) + } + + test("cube") { + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by cube(course, year)"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + } -- cgit v1.2.3