 python/pyspark/sql/functions.py                                                      |  14 +++++-------
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala |   1 +
 sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala                   |  76 +++++++++-
 sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala        | 143 ++++++++++++++
 4 files changed, 226 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 92e724fef4..88924e2981 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -348,13 +348,13 @@ def grouping_id(*cols):
grouping columns).
>>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show()
-    +-----+------------+--------+
-    | name|groupingid()|sum(age)|
-    +-----+------------+--------+
-    | null|           1|       7|
-    |Alice|           0|       2|
-    |  Bob|           0|       5|
-    +-----+------------+--------+
+    +-----+-------------+--------+
+    | name|grouping_id()|sum(age)|
+    +-----+-------------+--------+
+    | null|            1|       7|
+    |Alice|            0|       2|
+    |  Bob|            0|       5|
+    +-----+-------------+--------+
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
index a204060630..437e417266 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
@@ -63,4 +63,5 @@ case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Une
override def children: Seq[Expression] = groupByExprs
override def dataType: DataType = IntegerType
override def nullable: Boolean = false
+  override def prettyName: String = "grouping_id"
}
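The prettyName override above is what renames the auto-generated output column from groupingid() to grouping_id(), as the updated doctest shows. A minimal Scala sketch of the same check (illustrative only, not part of the patch; it assumes a Spark shell with sqlContext in scope and uses made-up data mirroring the doctest in functions.py):

    import org.apache.spark.sql.functions.{grouping_id, sum}

    // Hypothetical two-row dataset matching the doctest in functions.py.
    val df = sqlContext.createDataFrame(Seq(("Alice", 2), ("Bob", 5))).toDF("name", "age")

    // With the override, the generated column is named "grouping_id()";
    // before this change, the same query produced "groupingid()".
    df.cube("name").agg(grouping_id(), sum("age")).columns.foreach(println)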
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index 9a14ccff57..8d411a9a40 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.types.{DataType, NullType}
+import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, NullType}
/**
* A place holder for generated SQL for subquery expression.
@@ -118,6 +118,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
case p: Project =>
projectToSQL(p, isDistinct = false)
+    case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) =>
+      groupingSetToSQL(a, e, p)
+
case p: Aggregate =>
aggregateToSQL(p)
@@ -244,6 +247,77 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
)
}
+  private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
+    output1.size == output2.size &&
+      output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
+
+  private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
+    assert(a.child == e && e.child == p)
+    a.groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
+      sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute]))
+  }
+
+  private def groupingSetToSQL(
+      agg: Aggregate,
+      expand: Expand,
+      project: Project): String = {
+    assert(agg.groupingExpressions.length > 1)
+
+    // The last column of Expand is always the grouping ID.
+    val gid = expand.output.last
+
+    val numOriginalOutput = project.child.output.length
+    // Assumption: Aggregate's groupingExpressions is composed of
+    //   1) the attributes of the aliased group by expressions,
+    //   2) gid, which is always the last one.
+    val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
+    // Assumption: Project's projectList is composed of
+    //   1) the original output (Project's child.output),
+    //   2) the aliased group by expressions.
+    val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
+    val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
+
+    // A map from group by attributes to the original group by expressions.
+    val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
+
+    val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project =>
+      // Assumption: expand.projections is composed of
+      //   1) the original output (Project's child.output),
+      //   2) group by attributes (or null literals),
+      //   3) gid, which is always the last one in each projection of Expand.
+      project.drop(numOriginalOutput).dropRight(1).collect {
+        case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr)
+      }
+    }
+    val groupingSetSQL =
+      "GROUPING SETS(" +
+        groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"
+
+    val aggExprs = agg.aggregateExpressions.map { case expr =>
+      expr.transformDown {
+        // grouping_id() is converted to VirtualColumn.groupingIdName by the Analyzer;
+        // convert it back.
+        case ar: AttributeReference if ar == gid => GroupingID(Nil)
+        case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar)
+        case a @ Cast(BitwiseAnd(
+            ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)),
+            Literal(1, IntegerType)), ByteType) if ar == gid =>
+          // Convert this expression back to its original SQL form, grouping(col).
+          val idx = groupByExprs.length - 1 - value.asInstanceOf[Int]
+          groupByExprs.lift(idx).map(Grouping).getOrElse(a)
+      }
+    }
+
+    build(
+      "SELECT",
+      aggExprs.map(_.sql).mkString(", "),
+      if (agg.child == OneRowRelation) "" else "FROM",
+      toSQL(project.child),
+      "GROUP BY",
+      groupingSQL,
+      groupingSetSQL
+    )
+  }
+
object Canonicalizer extends RuleExecutor[LogicalPlan] {
override protected def batches: Seq[Batch] = Seq(
Batch("Canonicalizer", FixedPoint(100),
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index d708fcf8dd..f457d43e19 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -218,6 +218,149 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkHiveQl("SELECT DISTINCT id FROM parquet_t0")
}
+ test("rollup/cube #1") {
+ // Original logical plan:
+ // Aggregate [(key#17L % cast(5 as bigint))#47L,grouping__id#46],
+ // [(count(1),mode=Complete,isDistinct=false) AS cnt#43L,
+ // (key#17L % cast(5 as bigint))#47L AS _c1#45L,
+ // grouping__id#46 AS _c2#44]
+ // +- Expand [List(key#17L, value#18, (key#17L % cast(5 as bigint))#47L, 0),
+ // List(key#17L, value#18, null, 1)],
+ // [key#17L,value#18,(key#17L % cast(5 as bigint))#47L,grouping__id#46]
+ // +- Project [key#17L,
+ // value#18,
+ // (key#17L % cast(5 as bigint)) AS (key#17L % cast(5 as bigint))#47L]
+ // +- Subquery t1
+ // +- Relation[key#17L,value#18] ParquetRelation
+ // Converted SQL:
+ // SELECT count( 1) AS `cnt`,
+ // (`t1`.`key` % CAST(5 AS BIGINT)),
+ // grouping_id() AS `_c2`
+ // FROM `default`.`t1`
+ // GROUP BY (`t1`.`key` % CAST(5 AS BIGINT))
+ // GROUPING SETS (((`t1`.`key` % CAST(5 AS BIGINT))), ())
+ checkHiveQl(
+ "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH ROLLUP")
+ checkHiveQl(
+ "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH CUBE")
+ }
+
+ test("rollup/cube #2") {
+ checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH ROLLUP")
+ checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH CUBE")
+ }
+
+ test("rollup/cube #3") {
+ checkHiveQl(
+ "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH ROLLUP")
+ checkHiveQl(
+ "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH CUBE")
+ }
+
+ test("rollup/cube #4") {
+ checkHiveQl(
+ s"""
+ |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1
+ |GROUP BY key % 5, key - 5 WITH ROLLUP
+ """.stripMargin)
+ checkHiveQl(
+ s"""
+ |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1
+ |GROUP BY key % 5, key - 5 WITH CUBE
+ """.stripMargin)
+ }
+
+ test("rollup/cube #5") {
+ checkHiveQl(
+ s"""
+ |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3
+ |FROM (SELECT key, key%2, key - 5 FROM parquet_t1) t GROUP BY key%5, key-5
+ |WITH ROLLUP
+ """.stripMargin)
+ checkHiveQl(
+ s"""
+ |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3
+ |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5
+ |WITH CUBE
+ """.stripMargin)
+ }
+
+ test("rollup/cube #6") {
+ checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b")
+ checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b")
+ checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b")
+ checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b")
+ checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH ROLLUP")
+ checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH CUBE")
+ }
+
+ test("rollup/cube #7") {
+ checkHiveQl("SELECT a, b, grouping_id(a, b) FROM parquet_t2 GROUP BY cube(a, b)")
+ checkHiveQl("SELECT a, b, grouping(b) FROM parquet_t2 GROUP BY cube(a, b)")
+ checkHiveQl("SELECT a, b, grouping(a) FROM parquet_t2 GROUP BY cube(a, b)")
+ }
+
+ test("rollup/cube #8") {
+ // grouping_id() is part of another expression
+ checkHiveQl(
+ s"""
+ |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid
+ |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5
+ |WITH ROLLUP
+ """.stripMargin)
+ checkHiveQl(
+ s"""
+ |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid
+ |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5
+ |WITH CUBE
+ """.stripMargin)
+ }
+
+ test("rollup/cube #9") {
+ // self join is used as the child node of ROLLUP/CUBE with replaced quantifiers
+ checkHiveQl(
+ s"""
+ |SELECT t.key - 5, cnt, SUM(cnt)
+ |FROM (SELECT x.key, COUNT(*) as cnt
+ |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t
+ |GROUP BY cnt, t.key - 5
+ |WITH ROLLUP
+ """.stripMargin)
+ checkHiveQl(
+ s"""
+ |SELECT t.key - 5, cnt, SUM(cnt)
+ |FROM (SELECT x.key, COUNT(*) as cnt
+ |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t
+ |GROUP BY cnt, t.key - 5
+ |WITH CUBE
+ """.stripMargin)
+ }
+
+ test("grouping sets #1") {
+ checkHiveQl(
+ s"""
+ |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3
+ |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5
+ |GROUPING SETS (key % 5, key - 5)
+ """.stripMargin)
+ }
+
+ test("grouping sets #2") {
+ checkHiveQl(
+ "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b")
+ checkHiveQl(
+ "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b")
+ checkHiveQl(
+ "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b")
+ checkHiveQl(
+ "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (()) ORDER BY a, b")
+ checkHiveQl(
+ s"""
+ |SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b
+ |GROUPING SETS ((), (a), (a, b)) ORDER BY a, b
+ """.stripMargin)
+ }
+
test("cluster by") {
checkHiveQl("SELECT id FROM parquet_t0 CLUSTER BY id")
}