author    Davies Liu <davies@databricks.com>    2016-02-10 20:13:38 -0800
committer Davies Liu <davies.liu@gmail.com>    2016-02-10 20:13:38 -0800
commit    b5761d150b66ee0ae5f1be897d9d7a1abb039884 (patch)
tree      4d2f839c621b844f09d7e5045c23156cec3a12a6 /sql/core
parent    0f09f0226983cdc409ef504dff48395787dc844f (diff)
[SPARK-12706] [SQL] grouping() and grouping_id()
grouping() indicates whether a column in a GROUP BY list is aggregated or not; grouping_id() returns the aggregation level. Both can be used with window functions, but they do not yet work in HAVING/ORDER BY clauses; that will be fixed by another PR. Hive's GROUPING__ID/grouping_id() is wrong according to Hive's own docs, and we had implemented it the same wrong way; this PR changes it to match the behavior of most databases (and of the Hive docs).

Author: Davies Liu <davies@databricks.com>

Closes #10677 from davies/grouping.
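For illustration, a minimal sketch of the new API (using the courseSales fixture from this patch's tests, which has course, year, and earnings columns):

    import org.apache.spark.sql.functions._

    // cube() emits one row per aggregation level. grouping(c) is 0 when c is
    // part of the row's grouping and 1 when c has been rolled up (shown as null);
    // grouping_id packs those bits: (grouping(course) << 1) + grouping(year).
    courseSales.cube("course", "year")
      .agg(grouping("course"), grouping("year"),
        grouping_id("course", "year"), sum("earnings"))
      .show()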
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala               | 46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 44
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala           | 50
3 files changed, 140 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b970eee4e3..d34d377ab6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -396,6 +396,52 @@ object functions extends LegacyFunctions {
*/
def first(columnName: String): Column = first(Column(columnName))
+
+ /**
+ * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
+ * or not; returns 1 for aggregated and 0 for not aggregated in the result set.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
+ def grouping(e: Column): Column = Column(Grouping(e.expr))
+
+ /**
+ * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
+ * or not; returns 1 for aggregated and 0 for not aggregated in the result set.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
+ def grouping(columnName: String): Column = grouping(Column(columnName))
+
+ /**
+ * Aggregate function: returns the level of grouping, equal to
+ *
+ * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
+ *
+ * Note: the list of columns should match the grouping columns exactly, or be empty (which
+ * means all the grouping columns).
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
+ def grouping_id(cols: Column*): Column = Column(GroupingID(cols.map(_.expr)))
+
+ /**
+ * Aggregate function: returns the level of grouping, equal to
+ *
+ * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
+ *
+ * Note: the list of columns should match the grouping columns exactly.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
+ def grouping_id(colName: String, colNames: String*): Column = {
+ grouping_id((Seq(colName) ++ colNames).map(n => Column(n)) : _*)
+ }
+
/**
* Aggregate function: returns the kurtosis of the values in a group.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 08fb7c9d84..78bf6c1bce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.DecimalType
@@ -98,6 +99,49 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
assert(cube0.where("date IS NULL").count > 0)
}
+ test("grouping and grouping_id") {
+ checkAnswer(
+ courseSales.cube("course", "year")
+ .agg(grouping("course"), grouping("year"), grouping_id("course", "year")),
+ Row("Java", 2012, 0, 0, 0) ::
+ Row("Java", 2013, 0, 0, 0) ::
+ Row("Java", null, 0, 1, 1) ::
+ Row("dotNET", 2012, 0, 0, 0) ::
+ Row("dotNET", 2013, 0, 0, 0) ::
+ Row("dotNET", null, 0, 1, 1) ::
+ Row(null, 2012, 1, 0, 2) ::
+ Row(null, 2013, 1, 0, 2) ::
+ Row(null, null, 1, 1, 3) :: Nil
+ )
+
+ intercept[AnalysisException] {
+ courseSales.groupBy().agg(grouping("course")).explain()
+ }
+ intercept[AnalysisException] {
+ courseSales.groupBy().agg(grouping_id("course")).explain()
+ }
+ }
+
+ test("grouping/grouping_id inside window function") {
+
+ val w = Window.orderBy(sum("earnings"))
+ checkAnswer(
+ courseSales.cube("course", "year")
+ .agg(sum("earnings"),
+ grouping_id("course", "year"),
+ rank().over(Window.partitionBy(grouping_id("course", "year")).orderBy(sum("earnings")))),
+ Row("Java", 2012, 20000.0, 0, 2) ::
+ Row("Java", 2013, 30000.0, 0, 3) ::
+ Row("Java", null, 50000.0, 1, 1) ::
+ Row("dotNET", 2012, 15000.0, 0, 1) ::
+ Row("dotNET", 2013, 48000.0, 0, 4) ::
+ Row("dotNET", null, 63000.0, 1, 2) ::
+ Row(null, 2012, 35000.0, 2, 1) ::
+ Row(null, 2013, 78000.0, 2, 2) ::
+ Row(null, null, 113000.0, 3, 1) :: Nil
+ )
+ }
+
test("rollup overlapping columns") {
checkAnswer(
testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 8ef7b61314..f665a1c87b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2055,6 +2055,56 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
+ test("grouping sets") {
+ checkAnswer(
+ sql("select course, year, sum(earnings) from courseSales group by course, year " +
+ "grouping sets(course, year)"),
+ Row("Java", null, 50000.0) ::
+ Row("dotNET", null, 63000.0) ::
+ Row(null, 2012, 35000.0) ::
+ Row(null, 2013, 78000.0) :: Nil
+ )
+
+ checkAnswer(
+ sql("select course, year, sum(earnings) from courseSales group by course, year " +
+ "grouping sets(course)"),
+ Row("Java", null, 50000.0) ::
+ Row("dotNET", null, 63000.0) :: Nil
+ )
+
+ checkAnswer(
+ sql("select course, year, sum(earnings) from courseSales group by course, year " +
+ "grouping sets(year)"),
+ Row(null, 2012, 35000.0) ::
+ Row(null, 2013, 78000.0) :: Nil
+ )
+ }
+
+ test("grouping and grouping_id") {
+ checkAnswer(
+ sql("select course, year, grouping(course), grouping(year), grouping_id(course, year)" +
+ " from courseSales group by cube(course, year)"),
+ Row("Java", 2012, 0, 0, 0) ::
+ Row("Java", 2013, 0, 0, 0) ::
+ Row("Java", null, 0, 1, 1) ::
+ Row("dotNET", 2012, 0, 0, 0) ::
+ Row("dotNET", 2013, 0, 0, 0) ::
+ Row("dotNET", null, 0, 1, 1) ::
+ Row(null, 2012, 1, 0, 2) ::
+ Row(null, 2013, 1, 0, 2) ::
+ Row(null, null, 1, 1, 3) :: Nil
+ )
+
+ var error = intercept[AnalysisException] {
+ sql("select course, year, grouping(course) from courseSales group by course, year")
+ }
+ assert(error.getMessage contains "grouping() can only be used with GroupingSets/Cube/Rollup")
+ error = intercept[AnalysisException] {
+ sql("select course, year, grouping_id(course, year) from courseSales group by course, year")
+ }
+ assert(error.getMessage contains "grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ }
+
test("SPARK-13056: Null in map value causes NPE") {
val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value")
withTempTable("maptest") {
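For reference, the first grouping sets query in SQLQuerySuite above expands to a union of plain GROUP BY queries; a sketch of that equivalence (illustrative only, not part of this patch):

    // group by course, year grouping sets(course, year) is equivalent to:
    sql("select course, cast(null as int) as year, sum(earnings) from courseSales " +
      "group by course " +
      "union all " +
      "select cast(null as string) as course, year, sum(earnings) from courseSales " +
      "group by year")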