aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjiangxingbo <jiangxb1987@gmail.com>2016-11-08 15:11:03 +0100
committerHerman van Hovell <hvanhovell@databricks.com>2016-11-08 15:11:03 +0100
commit344dcad70173abcb348c68fdb0219960b5b06635 (patch)
tree703279c99408b5ceeb872a36ef4d3eb68c9b2721
parentb1033fb74595716a8973acae43a6415d8e0a76d2 (diff)
downloadspark-344dcad70173abcb348c68fdb0219960b5b06635.tar.gz
spark-344dcad70173abcb348c68fdb0219960b5b06635.tar.bz2
spark-344dcad70173abcb348c68fdb0219960b5b06635.zip
[SPARK-17868][SQL] Do not use bitmasks during parsing and analysis of CUBE/ROLLUP/GROUPING SETS
## What changes were proposed in this pull request? We generate bitmasks for grouping sets during the parsing process, and use these during analysis. These bitmasks are difficult to work with in practice and have lead to numerous bugs. This PR removes these and use actual sets instead, however we still need to generate these offsets for the grouping_id. This PR does the following works: 1. Replace bitmasks by actual grouping sets durning Parsing/Analysis stage of CUBE/ROLLUP/GROUPING SETS; 2. Add new testsuite `ResolveGroupingAnalyticsSuite` to test the `Analyzer.ResolveGroupingAnalytics` rule directly; 3. Fix a minor bug in `ResolveGroupingAnalytics`. ## How was this patch tested? By existing test cases, and add new testsuite `ResolveGroupingAnalyticsSuite` to test directly. Author: jiangxingbo <jiangxb1987@gmail.com> Closes #15484 from jiangxb1987/group-set.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala219
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala29
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala67
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala291
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala5
5 files changed, 474 insertions, 137 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8dbec40800..dd68d60d3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -217,11 +217,9 @@ class Analyzer(
* Group Count: N + 1 (N is the number of group expressions)
*
* We need to get all of its subsets for the rule described above, the subset is
- * represented as the bit masks.
+ * represented as sequence of expressions.
*/
- def bitmasks(r: Rollup): Seq[Int] = {
- Seq.tabulate(r.groupByExprs.length + 1)(idx => (1 << idx) - 1)
- }
+ def rollupExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.inits.toSeq
/*
* GROUP BY a, b, c WITH CUBE
@@ -230,10 +228,14 @@ class Analyzer(
* Group Count: 2 ^ N (N is the number of group expressions)
*
* We need to get all of its subsets for a given GROUPBY expression, the subsets are
- * represented as the bit masks.
+ * represented as sequence of expressions.
*/
- def bitmasks(c: Cube): Seq[Int] = {
- Seq.tabulate(1 << c.groupByExprs.length)(i => i)
+ def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.toList match {
+ case x :: xs =>
+ val initial = cubeExprs(xs)
+ initial.map(x +: _) ++ initial
+ case Nil =>
+ Seq(Seq.empty)
}
private def hasGroupingAttribute(expr: Expression): Boolean = {
@@ -256,17 +258,17 @@ class Analyzer(
expr transform {
case e: GroupingID =>
if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
- gid
+ Alias(gid, toPrettySQL(e))()
} else {
throw new AnalysisException(
s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
s"grouping columns (${groupByExprs.mkString(",")})")
}
- case Grouping(col: Expression) =>
+ case e @ Grouping(col: Expression) =>
val idx = groupByExprs.indexOf(col)
if (idx >= 0) {
- Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
- Literal(1)), ByteType)
+ Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
+ Literal(1)), ByteType), toPrettySQL(e))()
} else {
throw new AnalysisException(s"Column of grouping ($col) can't be found " +
s"in grouping columns ${groupByExprs.mkString(",")}")
@@ -274,85 +276,107 @@ class Analyzer(
}
}
- // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
- def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case a if !a.childrenResolved => a // be sure all of the children are resolved.
- case p if p.expressions.exists(hasGroupingAttribute) =>
- failAnalysis(
- s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")
-
- case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
- GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
- case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
- GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
+ /*
+ * Create new alias for all group by expressions for `Expand` operator.
+ */
+ private def constructGroupByAlias(groupByExprs: Seq[Expression]): Seq[Alias] = {
+ groupByExprs.map {
+ case e: NamedExpression => Alias(e, e.name)()
+ case other => Alias(other, other.toString)()
+ }
+ }
- // Ensure all the expressions have been resolved.
- case x: GroupingSets if x.expressions.forall(_.resolved) =>
- val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
-
- // Expand works by setting grouping expressions to null as determined by the bitmasks. To
- // prevent these null values from being used in an aggregate instead of the original value
- // we need to create new aliases for all group by expressions that will only be used for
- // the intended purpose.
- val groupByAliases: Seq[Alias] = x.groupByExprs.map {
- case e: NamedExpression => Alias(e, e.name)()
- case other => Alias(other, other.toString)()
+ /*
+ * Construct [[Expand]] operator with grouping sets.
+ */
+ private def constructExpand(
+ selectedGroupByExprs: Seq[Seq[Expression]],
+ child: LogicalPlan,
+ groupByAliases: Seq[Alias],
+ gid: Attribute): LogicalPlan = {
+ // Change the nullability of group by aliases if necessary. For example, if we have
+ // GROUPING SETS ((a,b), a), we do not need to change the nullability of a, but we
+ // should change the nullabilty of b to be TRUE.
+ // TODO: For Cube/Rollup just set nullability to be `true`.
+ val expandedAttributes = groupByAliases.map { alias =>
+ if (selectedGroupByExprs.exists(!_.contains(alias.child))) {
+ alias.toAttribute.withNullability(true)
+ } else {
+ alias.toAttribute
}
+ }
- // The rightmost bit in the bitmasks corresponds to the last expression in groupByAliases
- // with 0 indicating this expression is in the grouping set. The following line of code
- // calculates the bitmask representing the expressions that absent in at least one grouping
- // set (indicated by 1).
- val nullBitmask = x.bitmasks.reduce(_ | _)
-
- val attrLength = groupByAliases.length
- val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
- a.toAttribute.withNullability(((nullBitmask >> (attrLength - idx - 1)) & 1) == 1)
+ val groupingSetsAttributes = selectedGroupByExprs.map { groupingSetExprs =>
+ groupingSetExprs.map { expr =>
+ val alias = groupByAliases.find(_.child.semanticEquals(expr)).getOrElse(
+ failAnalysis(s"$expr doesn't show up in the GROUP BY list $groupByAliases"))
+ // Map alias to expanded attribute.
+ expandedAttributes.find(_.semanticEquals(alias.toAttribute)).getOrElse(
+ alias.toAttribute)
}
+ }
- val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child)
- val groupingAttrs = expand.output.drop(x.child.output.length)
+ Expand(groupingSetsAttributes, groupByAliases, expandedAttributes, gid, child)
+ }
- val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
- // collect all the found AggregateExpression, so we can check an expression is part of
- // any AggregateExpression or not.
- val aggsBuffer = ArrayBuffer[Expression]()
- // Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
- def isPartOfAggregation(e: Expression): Boolean = {
- aggsBuffer.exists(a => a.find(_ eq e).isDefined)
+ /*
+ * Construct new aggregate expressions by replacing grouping functions.
+ */
+ private def constructAggregateExprs(
+ groupByExprs: Seq[Expression],
+ aggregations: Seq[NamedExpression],
+ groupByAliases: Seq[Alias],
+ groupingAttrs: Seq[Expression],
+ gid: Attribute): Seq[NamedExpression] = aggregations.map {
+ // collect all the found AggregateExpression, so we can check an expression is part of
+ // any AggregateExpression or not.
+ val aggsBuffer = ArrayBuffer[Expression]()
+ // Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
+ def isPartOfAggregation(e: Expression): Boolean = {
+ aggsBuffer.exists(a => a.find(_ eq e).isDefined)
+ }
+ replaceGroupingFunc(_, groupByExprs, gid).transformDown {
+ // AggregateExpression should be computed on the unmodified value of its argument
+ // expressions, so we should not replace any references to grouping expression
+ // inside it.
+ case e: AggregateExpression =>
+ aggsBuffer += e
+ e
+ case e if isPartOfAggregation(e) => e
+ case e =>
+ // Replace expression by expand output attribute.
+ val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
+ if (index == -1) {
+ e
+ } else {
+ groupingAttrs(index)
}
- replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown {
- // AggregateExpression should be computed on the unmodified value of its argument
- // expressions, so we should not replace any references to grouping expression
- // inside it.
- case e: AggregateExpression =>
- aggsBuffer += e
- e
- case e if isPartOfAggregation(e) => e
- case e =>
- val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
- if (index == -1) {
- e
- } else {
- groupingAttrs(index)
- }
- }.asInstanceOf[NamedExpression]
- }
+ }.asInstanceOf[NamedExpression]
+ }
- Aggregate(groupingAttrs, aggregations, expand)
+ /*
+ * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
+ */
+ private def constructAggregate(
+ selectedGroupByExprs: Seq[Seq[Expression]],
+ groupByExprs: Seq[Expression],
+ aggregationExprs: Seq[NamedExpression],
+ child: LogicalPlan): LogicalPlan = {
+ val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
- case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
- val groupingExprs = findGroupingExprs(child)
- // The unresolved grouping id will be resolved by ResolveMissingReferences
- val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
- f.copy(condition = newCond)
+ // Expand works by setting grouping expressions to null as determined by the
+ // `selectedGroupByExprs`. To prevent these null values from being used in an aggregate
+ // instead of the original value we need to create new aliases for all group by expressions
+ // that will only be used for the intended purpose.
+ val groupByAliases = constructGroupByAlias(groupByExprs)
- case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) =>
- val groupingExprs = findGroupingExprs(child)
- val gid = VirtualColumn.groupingIdAttribute
- // The unresolved grouping id will be resolved by ResolveMissingReferences
- val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
- s.copy(order = newOrder)
+ val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid)
+ val groupingAttrs = expand.output.drop(child.output.length)
+
+ val aggregations = constructAggregateExprs(
+ groupByExprs, aggregationExprs, groupByAliases, groupingAttrs, gid)
+
+ Aggregate(groupingAttrs, aggregations, expand)
}
private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
@@ -369,6 +393,41 @@ class Analyzer(
failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
}
}
+
+ // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case a if !a.childrenResolved => a // be sure all of the children are resolved.
+ case p if p.expressions.exists(hasGroupingAttribute) =>
+ failAnalysis(
+ s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")
+
+ // Ensure group by expressions and aggregate expressions have been resolved.
+ case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child)
+ if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
+ constructAggregate(cubeExprs(groupByExprs), groupByExprs, aggregateExpressions, child)
+ case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child)
+ if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
+ constructAggregate(rollupExprs(groupByExprs), groupByExprs, aggregateExpressions, child)
+ // Ensure all the expressions have been resolved.
+ case x: GroupingSets if x.expressions.forall(_.resolved) =>
+ constructAggregate(x.selectedGroupByExprs, x.groupByExprs, x.aggregations, x.child)
+
+ // We should make sure all expressions in condition have been resolved.
+ case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved =>
+ val groupingExprs = findGroupingExprs(child)
+ // The unresolved grouping id will be resolved by ResolveMissingReferences
+ val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
+ f.copy(condition = newCond)
+
+ // We should make sure all [[SortOrder]]s have been resolved.
+ case s @ Sort(order, _, child)
+ if order.exists(hasGroupingFunction) && order.forall(_.resolved) =>
+ val groupingExprs = findGroupingExprs(child)
+ val gid = VirtualColumn.groupingIdAttribute
+ // The unresolved grouping id will be resolved by ResolveMissingReferences
+ val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
+ s.copy(order = newOrder)
+ }
}
object ResolvePivot extends Rule[LogicalPlan] {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 4b151c81d8..2c4db0d2c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -492,33 +492,18 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
ctx: AggregationContext,
selectExpressions: Seq[NamedExpression],
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
- import ctx._
- val groupByExpressions = expressionList(groupingExpressions)
+ val groupByExpressions = expressionList(ctx.groupingExpressions)
- if (GROUPING != null) {
+ if (ctx.GROUPING != null) {
// GROUP BY .... GROUPING SETS (...)
- val expressionMap = groupByExpressions.zipWithIndex.toMap
- val numExpressions = expressionMap.size
- val mask = (1 << numExpressions) - 1
- val masks = ctx.groupingSet.asScala.map {
- _.expression.asScala.foldLeft(mask) {
- case (bitmap, eCtx) =>
- // Find the index of the expression.
- val e = typedVisit[Expression](eCtx)
- val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse(
- throw new ParseException(
- s"$e doesn't show up in the GROUP BY list", ctx))
- // 0 means that the column at the given index is a grouping column, 1 means it is not,
- // so we unset the bit in bitmap.
- bitmap & ~(1 << (numExpressions - 1 - index))
- }
- }
- GroupingSets(masks, groupByExpressions, query, selectExpressions)
+ val selectedGroupByExprs =
+ ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)))
+ GroupingSets(selectedGroupByExprs, groupByExpressions, query, selectExpressions)
} else {
// GROUP BY .... (WITH CUBE | WITH ROLLUP)?
- val mappedGroupByExpressions = if (CUBE != null) {
+ val mappedGroupByExpressions = if (ctx.CUBE != null) {
Seq(Cube(groupByExpressions))
- } else if (ROLLUP != null) {
+ } else if (ctx.ROLLUP != null) {
Seq(Rollup(groupByExpressions))
} else {
groupByExpressions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 65ceab2ce2..dcae7b026f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical
-import scala.collection.mutable.ArrayBuffer
-
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTypes
@@ -523,51 +522,56 @@ case class Window(
object Expand {
/**
- * Extract attribute set according to the grouping id.
+ * Build bit mask from attributes of selected grouping set. A bit in the bitmask is corresponding
+ * to an attribute in group by attributes sequence, the selected attribute has corresponding bit
+ * set to 0 and otherwise set to 1. For example, if we have GroupBy attributes (a, b, c, d), the
+ * bitmask 5(whose binary form is 0101) represents grouping set (a, c).
*
- * @param bitmask bitmask to represent the selected of the attribute sequence
- * @param attrs the attributes in sequence
- * @return the attributes of non selected specified via bitmask (with the bit set to 1)
+ * @param groupingSetAttrs The attributes of selected grouping set
+ * @param attrMap Mapping group by attributes to its index in attributes sequence
+ * @return The bitmask which represents the selected attributes out of group by attributes.
*/
- private def buildNonSelectAttrSet(
- bitmask: Int,
- attrs: Seq[Attribute]): AttributeSet = {
- val nonSelect = new ArrayBuffer[Attribute]()
-
- var bit = attrs.length - 1
- while (bit >= 0) {
- if (((bitmask >> bit) & 1) == 1) nonSelect += attrs(attrs.length - bit - 1)
- bit -= 1
- }
-
- AttributeSet(nonSelect)
+ private def buildBitmask(
+ groupingSetAttrs: Seq[Attribute],
+ attrMap: Map[Attribute, Int]): Int = {
+ val numAttributes = attrMap.size
+ val mask = (1 << numAttributes) - 1
+ // Calculate the attrbute masks of selected grouping set. For example, if we have GroupBy
+ // attributes (a, b, c, d), grouping set (a, c) will produce the following sequence:
+ // (15, 7, 13), whose binary form is (1111, 0111, 1101)
+ val masks = (mask +: groupingSetAttrs.map(attrMap).map(index =>
+ // 0 means that the column at the given index is a grouping column, 1 means it is not,
+ // so we unset the bit in bitmap.
+ ~(1 << (numAttributes - 1 - index))
+ ))
+ // Reduce masks to generate an bitmask for the selected grouping set.
+ masks.reduce(_ & _)
}
/**
* Apply the all of the GroupExpressions to every input row, hence we will get
* multiple output rows for an input row.
*
- * @param bitmasks The bitmask set represents the grouping sets
+ * @param groupingSetsAttrs The attributes of grouping sets
* @param groupByAliases The aliased original group by expressions
* @param groupByAttrs The attributes of aliased group by expressions
* @param gid Attribute of the grouping id
* @param child Child operator
*/
def apply(
- bitmasks: Seq[Int],
+ groupingSetsAttrs: Seq[Seq[Attribute]],
groupByAliases: Seq[Alias],
groupByAttrs: Seq[Attribute],
gid: Attribute,
child: LogicalPlan): Expand = {
+ val attrMap = groupByAttrs.zipWithIndex.toMap
+
// Create an array of Projections for the child projection, and replace the projections'
// expressions which equal GroupBy expressions with Literal(null), if those expressions
- // are not set for this grouping set (according to the bit mask).
- val projections = bitmasks.map { bitmask =>
- // get the non selected grouping attributes according to the bit mask
- val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, groupByAttrs)
-
+ // are not set for this grouping set.
+ val projections = groupingSetsAttrs.map { groupingSetAttrs =>
child.output ++ groupByAttrs.map { attr =>
- if (nonSelectedGroupAttrSet.contains(attr)) {
+ if (!groupingSetAttrs.contains(attr)) {
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
Literal.create(null, attr.dataType)
@@ -575,7 +579,7 @@ object Expand {
attr
}
// groupingId is the last output, here we use the bit mask as the concrete value for it.
- } :+ Literal.create(bitmask, IntegerType)
+ } :+ Literal.create(buildBitmask(groupingSetAttrs, attrMap), IntegerType)
}
// the `groupByAttrs` has different meaning in `Expand.output`, it could be the original
@@ -616,16 +620,15 @@ case class Expand(
*
* We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer
*
- * @param bitmasks A list of bitmasks, each of the bitmask indicates the selected
- * GroupBy expressions
- * @param groupByExprs The Group By expressions candidates, take effective only if the
- * associated bit in the bitmask set to 1.
+ * @param selectedGroupByExprs A sequence of selected GroupBy expressions, all exprs should
+ * exists in groupByExprs.
+ * @param groupByExprs The Group By expressions candidates.
* @param child Child operator
* @param aggregations The Aggregation expressions, those non selected group by expressions
* will be considered as constant null if it appears in the expressions
*/
case class GroupingSets(
- bitmasks: Seq[Int],
+ selectedGroupByExprs: Seq[Seq[Expression]],
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression]) extends UnaryNode {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
new file mode 100644
index 0000000000..2a0205bdc9
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types._
+
+class ResolveGroupingAnalyticsSuite extends AnalysisTest {
+
+ lazy val a = 'a.int
+ lazy val b = 'b.string
+ lazy val c = 'c.string
+ lazy val unresolved_a = UnresolvedAttribute("a")
+ lazy val unresolved_b = UnresolvedAttribute("b")
+ lazy val unresolved_c = UnresolvedAttribute("c")
+ lazy val gid = 'spark_grouping_id.int.withNullability(false)
+ lazy val hive_gid = 'grouping__id.int.withNullability(false)
+ lazy val grouping_a = Cast(ShiftRight(gid, 1) & 1, ByteType)
+ lazy val nulInt = Literal(null, IntegerType)
+ lazy val nulStr = Literal(null, StringType)
+ lazy val r1 = LocalRelation(a, b, c)
+
+ test("rollupExprs") {
+ val testRollup = (exprs: Seq[Expression], rollup: Seq[Seq[Expression]]) => {
+ val result = SimpleAnalyzer.ResolveGroupingAnalytics.rollupExprs(exprs)
+ assert(result.sortBy(_.hashCode) == rollup.sortBy(_.hashCode))
+ }
+
+ testRollup(Seq(a, b, c), Seq(Seq(), Seq(a), Seq(a, b), Seq(a, b, c)))
+ testRollup(Seq(c, b, a), Seq(Seq(), Seq(c), Seq(c, b), Seq(c, b, a)))
+ testRollup(Seq(a), Seq(Seq(), Seq(a)))
+ testRollup(Seq(), Seq(Seq()))
+ }
+
+ test("cubeExprs") {
+ val testCube = (exprs: Seq[Expression], cube: Seq[Seq[Expression]]) => {
+ val result = SimpleAnalyzer.ResolveGroupingAnalytics.cubeExprs(exprs)
+ assert(result.sortBy(_.hashCode) == cube.sortBy(_.hashCode))
+ }
+
+ testCube(Seq(a, b, c),
+ Seq(Seq(), Seq(a), Seq(b), Seq(c), Seq(a, b), Seq(a, c), Seq(b, c), Seq(a, b, c)))
+ testCube(Seq(c, b, a),
+ Seq(Seq(), Seq(a), Seq(b), Seq(c), Seq(c, b), Seq(c, a), Seq(b, a), Seq(c, b, a)))
+ testCube(Seq(a), Seq(Seq(), Seq(a)))
+ testCube(Seq(), Seq(Seq()))
+ }
+
+ test("grouping sets") {
+ val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+ Seq(unresolved_a, unresolved_b), r1,
+ Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))))
+ val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")),
+ Expand(
+ Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+ checkAnalysis(originalPlan, expected)
+
+ val originalPlan2 = GroupingSets(Seq(), Seq(unresolved_a, unresolved_b), r1,
+ Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))))
+ val expected2 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")),
+ Expand(
+ Seq(),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+ checkAnalysis(originalPlan2, expected2)
+
+ val originalPlan3 = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b),
+ Seq(unresolved_c)), Seq(unresolved_a, unresolved_b), r1,
+ Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))))
+ assertAnalysisError(originalPlan3, Seq("doesn't show up in the GROUP BY list"))
+ }
+
+ test("cube") {
+ val originalPlan = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))),
+ Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1)
+ val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")),
+ Expand(
+ Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1),
+ Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+ checkAnalysis(originalPlan, expected)
+
+ val originalPlan2 = Aggregate(Seq(Cube(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1)
+ val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")),
+ Expand(
+ Seq(Seq(a, b, c, 0)),
+ Seq(a, b, c, gid),
+ Project(Seq(a, b, c), r1)))
+ checkAnalysis(originalPlan2, expected2)
+ }
+
+ test("rollup") {
+ val originalPlan = Aggregate(Seq(Rollup(Seq(unresolved_a, unresolved_b))),
+ Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1)
+ val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")),
+ Expand(
+ Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+ checkAnalysis(originalPlan, expected)
+
+ val originalPlan2 = Aggregate(Seq(Rollup(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1)
+ val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")),
+ Expand(
+ Seq(Seq(a, b, c, 0)),
+ Seq(a, b, c, gid),
+ Project(Seq(a, b, c), r1)))
+ checkAnalysis(originalPlan2, expected2)
+ }
+
+ test("grouping function") {
+ // GrouingSets
+ val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+ Seq(unresolved_a, unresolved_b), r1,
+ Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+ UnresolvedAlias(Grouping(unresolved_a))))
+ val expected = Aggregate(Seq(a, b, gid),
+ Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")),
+ Expand(
+ Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+ checkAnalysis(originalPlan, expected)
+
+ // Cube
+ val originalPlan2 = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))),
+ Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+ UnresolvedAlias(Grouping(unresolved_a))), r1)
+ val expected2 = Aggregate(Seq(a, b, gid),
+ Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")),
+ Expand(
+ Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1),
+ Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+ checkAnalysis(originalPlan2, expected2)
+
+ // Rollup
+ val originalPlan3 = Aggregate(Seq(Rollup(Seq(unresolved_a, unresolved_b))),
+ Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+ UnresolvedAlias(Grouping(unresolved_a))), r1)
+ val expected3 = Aggregate(Seq(a, b, gid),
+ Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")),
+ Expand(
+ Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+ checkAnalysis(originalPlan3, expected3)
+ }
+
+ test("grouping_id") {
+ // GrouingSets
+ val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+ Seq(unresolved_a, unresolved_b), r1,
+ Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+ UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))))
+ val expected = Aggregate(Seq(a, b, gid),
+ Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")),
+ Expand(
+ Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+ checkAnalysis(originalPlan, expected)
+
+ // Cube
+ val originalPlan2 = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))),
+ Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+ UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1)
+ val expected2 = Aggregate(Seq(a, b, gid),
+ Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")),
+ Expand(
+ Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1),
+ Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+ checkAnalysis(originalPlan2, expected2)
+
+ // Rollup
+ val originalPlan3 = Aggregate(Seq(Rollup(Seq(unresolved_a, unresolved_b))),
+ Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+ UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1)
+ val expected3 = Aggregate(Seq(a, b, gid),
+ Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")),
+ Expand(
+ Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+ checkAnalysis(originalPlan3, expected3)
+ }
+
+ test("filter with grouping function") {
+ // Filter with Grouping function
+ val originalPlan = Filter(Grouping(unresolved_a) === 0,
+ GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+ Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b)))
+ val expected = Project(Seq(a, b), Filter(Cast(grouping_a, IntegerType) === 0,
+ Aggregate(Seq(a, b, gid),
+ Seq(a, b, gid),
+ Expand(
+ Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))))
+ checkAnalysis(originalPlan, expected)
+
+ val originalPlan2 = Filter(Grouping(unresolved_a) === 0,
+ Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1))
+ assertAnalysisError(originalPlan2,
+ Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup"))
+
+ // Filter with GroupingID
+ val originalPlan3 = Filter(GroupingID(Seq(unresolved_a, unresolved_b)) === 1,
+ GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+ Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b)))
+ val expected3 = Project(Seq(a, b), Filter(gid === 1,
+ Aggregate(Seq(a, b, gid),
+ Seq(a, b, gid),
+ Expand(
+ Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))))
+ checkAnalysis(originalPlan3, expected3)
+
+ val originalPlan4 = Filter(GroupingID(Seq(unresolved_a)) === 1,
+ Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1))
+ assertAnalysisError(originalPlan4,
+ Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup"))
+ }
+
+ test("sort with grouping function") {
+ // Sort with Grouping function
+ val originalPlan = Sort(
+ Seq(SortOrder(Grouping(unresolved_a), Ascending)), true,
+ GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+ Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b)))
+ val expected = Project(Seq(a, b), Sort(
+ Seq(SortOrder('aggOrder.byte.withNullability(false), Ascending)), true,
+ Aggregate(Seq(a, b, gid),
+ Seq(a, b, grouping_a.as("aggOrder")),
+ Expand(
+ Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))))
+ checkAnalysis(originalPlan, expected)
+
+ val originalPlan2 = Sort(Seq(SortOrder(Grouping(unresolved_a), Ascending)), true,
+ Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1))
+ assertAnalysisError(originalPlan2,
+ Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup"))
+
+ // Sort with GroupingID
+ val originalPlan3 = Sort(
+ Seq(SortOrder(GroupingID(Seq(unresolved_a, unresolved_b)), Ascending)), true,
+ GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+ Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b)))
+ val expected3 = Project(Seq(a, b), Sort(
+ Seq(SortOrder('aggOrder.int.withNullability(false), Ascending)), true,
+ Aggregate(Seq(a, b, gid),
+ Seq(a, b, gid.as("aggOrder")),
+ Expand(
+ Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+ Seq(a, b, c, a, b, gid),
+ Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))))
+ checkAnalysis(originalPlan3, expected3)
+
+ val originalPlan4 = Sort(
+ Seq(SortOrder(GroupingID(Seq(unresolved_a)), Ascending)), true,
+ Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1))
+ assertAnalysisError(originalPlan4,
+ Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup"))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 7400f3430e..5f0f6ee479 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -233,9 +233,8 @@ class PlanParserSuite extends PlanTest {
// Grouping Sets
assertEqual(s"$sql grouping sets((a, b), (a), ())",
- GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c"))))
- intercept(s"$sql grouping sets((a, b), (c), ())",
- "c doesn't show up in the GROUP BY list")
+ GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b), table("d"),
+ Seq('a, 'b, 'sum.function('c).as("c"))))
}
test("limit") {