author     Davies Liu <davies@databricks.com>   2016-01-04 14:26:56 -0800
committer  Davies Liu <davies.liu@gmail.com>    2016-01-04 14:26:56 -0800
commit     d084a2de3271fd8b0d29ee67e1e214e8b9d96d6d (patch)
tree       208d156e5ee567d81ca35acbfcc3c0df929e5386 /sql/core
parent     93ef9b6a2aa1830170cb101f191022f2dda62c41 (diff)
[SPARK-12541] [SQL] support cube/rollup as function
This PR enables cube/rollup as functions, so they can be used like this:

```
select a, b, sum(c) from t group by rollup(a, b)
```

Author: Davies Liu <davies@databricks.com>

Closes #10522 from davies/rollup.
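For context, a minimal end-to-end sketch of the new syntax, assuming an existing SQLContext named `sqlContext` and a registered temp table `t(a, b, c)` (both illustrative stand-ins, not part of the patch):

```scala
// rollup(a, b) aggregates over the grouping sets (a, b), (a), and ().
val rolled = sqlContext.sql(
  "select a, b, sum(c) from t group by rollup(a, b)")

// cube(a, b) additionally includes the (b) grouping set.
val cubed = sqlContext.sql(
  "select a, b, sum(c) from t group by cube(a, b)")

rolled.show()
cubed.show()
```

Expressing cube/rollup as functions lets them resolve through the regular Aggregate path, which is what the GroupedData.scala change below switches the DataFrame API over to.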
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala   |  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 29
2 files changed, 32 insertions(+), 3 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 13341a88a6..2aa82f1496 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
+import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Aggregate}
import org.apache.spark.sql.types.NumericType
@@ -58,10 +58,10 @@ class GroupedData protected[sql](
df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
case GroupedData.RollupType =>
DataFrame(
- df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg))
+ df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan))
case GroupedData.CubeType =>
DataFrame(
- df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
+ df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan))
case GroupedData.PivotType(pivotCol, values) =>
val aliasedGrps = groupingExprs.map(alias)
DataFrame(
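The net effect of this hunk: `rollup`/`cube` no longer map to dedicated `Rollup`/`Cube` logical plans (note the trimmed import above) but to an ordinary `Aggregate` whose grouping expressions carry a `Rollup`/`Cube` expression. The DataFrame entry points themselves are unchanged; a sketch, assuming an illustrative DataFrame `df` with columns `a`, `b`, `c`:

```scala
// Both calls dispatch through the RollupType/CubeType cases above and,
// after this patch, build a plain Aggregate node wrapping a Rollup/Cube
// expression rather than a dedicated logical plan.
df.rollup("a", "b").sum("c")
df.cube("a", "b").sum("c")
```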
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bb82b562aa..115b617c21 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2028,4 +2028,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(false) :: Row(true) :: Nil)
}
+ test("rollup") {
+ checkAnswer(
+ sql("select course, year, sum(earnings) from courseSales group by rollup(course, year)" +
+ " order by course, year"),
+ Row(null, null, 113000.0) ::
+ Row("Java", null, 50000.0) ::
+ Row("Java", 2012, 20000.0) ::
+ Row("Java", 2013, 30000.0) ::
+ Row("dotNET", null, 63000.0) ::
+ Row("dotNET", 2012, 15000.0) ::
+ Row("dotNET", 2013, 48000.0) :: Nil
+ )
+ }
+
+ test("cube") {
+ checkAnswer(
+ sql("select course, year, sum(earnings) from courseSales group by cube(course, year)"),
+ Row("Java", 2012, 20000.0) ::
+ Row("Java", 2013, 30000.0) ::
+ Row("Java", null, 50000.0) ::
+ Row("dotNET", 2012, 15000.0) ::
+ Row("dotNET", 2013, 48000.0) ::
+ Row("dotNET", null, 63000.0) ::
+ Row(null, 2012, 35000.0) ::
+ Row(null, 2013, 78000.0) ::
+ Row(null, null, 113000.0) :: Nil
+ )
+ }
+
}
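A sanity check on the expectations above: with two grouping columns, `rollup(course, year)` expands to the grouping sets {(course, year), (course), ()}, giving 4 + 2 + 1 = 7 rows, while `cube(course, year)` adds the {(year)} set for 4 + 2 + 2 + 1 = 9 rows, matching the two tests. A sketch of a stand-in fixture, reverse-engineered from the expected sums (the real `courseSales` data in the suite may differ):

```scala
// One row per (course, year) pair, chosen so the sums match the
// expectations above: Java 20000 + 30000 = 50000, dotNET
// 15000 + 48000 = 63000, grand total 113000. Illustrative only.
val data = Seq(
  ("Java", 2012, 20000.0),
  ("Java", 2013, 30000.0),
  ("dotNET", 2012, 15000.0),
  ("dotNET", 2013, 48000.0))

sqlContext.createDataFrame(data)
  .toDF("course", "year", "earnings")
  .registerTempTable("courseSales")

sqlContext.sql(
  "select course, year, sum(earnings) from courseSales" +
  " group by rollup(course, year) order by course, year").show()
```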