diff options
author | Cheng Hao <hao.cheng@intel.com> | 2014-12-18 18:58:29 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-12-18 18:58:29 -0800 |
commit | f728e0fe7e860fe6dd3437e248472a67a2d435f8 (patch) | |
tree | 28a78bc4c5a9820d9558471882760c0134997c12 /sql/catalyst | |
parent | 9804a759b68f56eceb8a2f4ea90f76a92b5f9f67 (diff) | |
download | spark-f728e0fe7e860fe6dd3437e248472a67a2d435f8.tar.gz spark-f728e0fe7e860fe6dd3437e248472a67a2d435f8.tar.bz2 spark-f728e0fe7e860fe6dd3437e248472a67a2d435f8.zip |
[SPARK-2663] [SQL] Support the Grouping Set
Add support for `GROUPING SETS`, `ROLLUP`, `CUBE` and the the virtual column `GROUPING__ID`.
More details on how to use the `GROUPING SETS" can be found at: https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation,+Cube,+Grouping+and+Rollup
https://issues.apache.org/jira/secure/attachment/12676811/grouping_set.pdf
The generic idea of the implementations are :
1 Replace the `ROLLUP`, `CUBE` with `GROUPING SETS`
2 Explode each of the input row, and then feed them to `Aggregate`
* Each grouping set are represented as the bit mask for the `GroupBy Expression List`, for each bit, `1` means the expression is selected, otherwise `0` (left is the lower bit, and right is the higher bit in the `GroupBy Expression List`)
* Several of projections are constructed according to the grouping sets, and within each projection(Seq[Expression), we replace those expressions with `Literal(null)` if it's not selected in the grouping set (based on the bit mask)
* Output Schema of `Explode` is `child.output :+ grouping__id`
* GroupBy Expressions of `Aggregate` is `GroupBy Expression List :+ grouping__id`
* Keep the `Aggregation expressions` the same for the `Aggregate`
The expressions substitutions happen in Logic Plan analyzing, so we will benefit from the Logical Plan optimization (e.g. expression constant folding, and map side aggregation etc.), Only an `Explosive` operator added for Physical Plan, which will explode the rows according the pre-set projections.
A known issue will be done in the follow up PR:
* Optimization `ColumnPruning` is not supported yet for `Explosive` node.
Author: Cheng Hao <hao.cheng@intel.com>
Closes #1567 from chenghao-intel/grouping_sets and squashes the following commits:
fe65fcc [Cheng Hao] Remove the extra space
3547056 [Cheng Hao] Add more doc and Simplify the Expand
a7c869d [Cheng Hao] update code as feedbacks
d23c672 [Cheng Hao] Add GroupingExpression to replace the Seq[Expression]
414b165 [Cheng Hao] revert the unnecessary changes
ec276c6 [Cheng Hao] Support Rollup/Cube/GroupingSets
Diffstat (limited to 'sql/catalyst')
4 files changed, 195 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3705fcc1f1..1c4088b843 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types.StructType +import org.apache.spark.sql.catalyst.types.IntegerType /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing @@ -56,6 +58,7 @@ class Analyzer(catalog: Catalog, Batch("Resolution", fixedPoint, ResolveReferences :: ResolveRelations :: + ResolveGroupingAnalytics :: ResolveSortReferences :: NewRelationInstances :: ImplicitGenerate :: @@ -102,6 +105,93 @@ class Analyzer(catalog: Catalog, } } + object ResolveGroupingAnalytics extends Rule[LogicalPlan] { + /** + * Extract attribute set according to the grouping id + * @param bitmask bitmask to represent the selected of the attribute sequence + * @param exprs the attributes in sequence + * @return the attributes of non selected specified via bitmask (with the bit set to 1) + */ + private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression]) + : OpenHashSet[Expression] = { + val set = new OpenHashSet[Expression](2) + + var bit = exprs.length - 1 + while (bit >= 0) { + if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit)) + bit -= 1 + } + + set + } + + /* + * GROUP BY a, b, c, WITH ROLLUP + * is equivalent to + * GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (a), ( )). + * Group Count: N + 1 (N is the number of group expression) + * + * We need to get all of its subsets for the rule described above, the subset is + * represented as the bit masks. + */ + def bitmasks(r: Rollup): Seq[Int] = { + Seq.tabulate(r.groupByExprs.length + 1)(idx => {(1 << idx) - 1}) + } + + /* + * GROUP BY a, b, c, WITH CUBE + * is equivalent to + * GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (b, c), (a, c), (a), (b), (c), ( ) ). + * Group Count: 2^N (N is the number of group expression) + * + * We need to get all of its sub sets for a given GROUPBY expressions, the subset is + * represented as the bit masks. + */ + def bitmasks(c: Cube): Seq[Int] = { + Seq.tabulate(1 << c.groupByExprs.length)(i => i) + } + + /** + * Create an array of Projections for the child projection, and replace the projections' + * expressions which equal GroupBy expressions with Literal(null), if those expressions + * are not set for this grouping set (according to the bit mask). + */ + private[this] def expand(g: GroupingSets): Seq[GroupExpression] = { + val result = new scala.collection.mutable.ArrayBuffer[GroupExpression] + + g.bitmasks.foreach { bitmask => + // get the non selected grouping attributes according to the bit mask + val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs) + + val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown { + case x: Expression if nonSelectedGroupExprSet.contains(x) => + // if the input attribute in the Invalid Grouping Expression set of for this group + // replace it with constant null + Literal(null, expr.dataType) + case x if x == g.gid => + // replace the groupingId with concrete value (the bit mask) + Literal(bitmask, IntegerType) + }) + + result += GroupExpression(substitution) + } + + result.toSeq + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a: Cube if a.resolved => + GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid) + case a: Rollup if a.resolved => + GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid) + case x: GroupingSets if x.resolved => + Aggregate( + x.groupByExprs :+ x.gid, + x.aggregations, + Expand(expand(x), x.child.output :+ x.gid, x.child)) + } + } + /** * Checks for non-aggregated attributes with aggregation */ @@ -183,6 +273,11 @@ class Analyzer(catalog: Catalog, case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressions { + case u @ UnresolvedAttribute(name) + if resolver(name, VirtualColumn.groupingIdName) && + q.isInstanceOf[GroupingAnalytics] => + // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics + q.asInstanceOf[GroupingAnalytics].gid case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = q.resolveChildren(name, resolver).getOrElse(u) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index bc45881e42..ac5b02c2e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -284,6 +284,17 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => +} - +// TODO Semantically we probably not need GroupExpression +// All we need is holding the Seq[Expression], and ONLY used in doing the +// expressions transformation correctly. Probably will be removed since it's +// not like a real expressions. +case class GroupExpression(children: Seq[Expression]) extends Expression { + self: Product => + type EvaluatedType = Seq[Any] + override def eval(input: Row): EvaluatedType = ??? + override def nullable = false + override def foldable = false + override def dataType = ??? } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 7634d392d4..a3c300b5d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -187,3 +187,8 @@ case class AttributeReference( override def toString: String = s"$name#${exprId.id}$typeSuffix" } + +object VirtualColumn { + val groupingIdName = "grouping__id" + def newGroupingId = AttributeReference(groupingIdName, IntegerType, false)() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 64b8d45ebb..a9282b98ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -143,6 +143,89 @@ case class Aggregate( override def output = aggregateExpressions.map(_.toAttribute) } +/** + * Apply the all of the GroupExpressions to every input row, hence we will get + * multiple output rows for a input row. + * @param projections The group of expressions, all of the group expressions should + * output the same schema specified by the parameter `output` + * @param output The output Schema + * @param child Child operator + */ +case class Expand( + projections: Seq[GroupExpression], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode + +trait GroupingAnalytics extends UnaryNode { + self: Product => + def gid: AttributeReference + def groupByExprs: Seq[Expression] + def aggregations: Seq[NamedExpression] + + override def output = aggregations.map(_.toAttribute) +} + +/** + * A GROUP BY clause with GROUPING SETS can generate a result set equivalent + * to generated by a UNION ALL of multiple simple GROUP BY clauses. + * + * We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer + * @param bitmasks A list of bitmasks, each of the bitmask indicates the selected + * GroupBy expressions + * @param groupByExprs The Group By expressions candidates, take effective only if the + * associated bit in the bitmask set to 1. + * @param child Child operator + * @param aggregations The Aggregation expressions, those non selected group by expressions + * will be considered as constant null if it appears in the expressions + * @param gid The attribute represents the virtual column GROUPING__ID, and it's also + * the bitmask indicates the selected GroupBy Expressions for each + * aggregating output row. + * The associated output will be one of the value in `bitmasks` + */ +case class GroupingSets( + bitmasks: Seq[Int], + groupByExprs: Seq[Expression], + child: LogicalPlan, + aggregations: Seq[NamedExpression], + gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + +/** + * Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, + * and eventually will be transformed to Aggregate(.., Expand) in Analyzer + * + * @param groupByExprs The Group By expressions candidates. + * @param child Child operator + * @param aggregations The Aggregation expressions, those non selected group by expressions + * will be considered as constant null if it appears in the expressions + * @param gid The attribute represents the virtual column GROUPING__ID, and it's also + * the bitmask indicates the selected GroupBy Expressions for each + * aggregating output row. + */ +case class Cube( + groupByExprs: Seq[Expression], + child: LogicalPlan, + aggregations: Seq[NamedExpression], + gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + +/** + * Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, + * and eventually will be transformed to Aggregate(.., Expand) in Analyzer + * + * @param groupByExprs The Group By expressions candidates, take effective only if the + * associated bit in the bitmask set to 1. + * @param child Child operator + * @param aggregations The Aggregation expressions, those non selected group by expressions + * will be considered as constant null if it appears in the expressions + * @param gid The attribute represents the virtual column GROUPING__ID, and it's also + * the bitmask indicates the selected GroupBy Expressions for each + * aggregating output row. + */ +case class Rollup( + groupByExprs: Seq[Expression], + child: LogicalPlan, + aggregations: Seq[NamedExpression], + gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output = child.output |