diff options
Diffstat (limited to 'sql/catalyst')
4 files changed, 195 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3705fcc1f1..1c4088b843 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types.StructType +import org.apache.spark.sql.catalyst.types.IntegerType /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing @@ -56,6 +58,7 @@ class Analyzer(catalog: Catalog, Batch("Resolution", fixedPoint, ResolveReferences :: ResolveRelations :: + ResolveGroupingAnalytics :: ResolveSortReferences :: NewRelationInstances :: ImplicitGenerate :: @@ -102,6 +105,93 @@ class Analyzer(catalog: Catalog, } } + object ResolveGroupingAnalytics extends Rule[LogicalPlan] { + /** + * Extract attribute set according to the grouping id + * @param bitmask bitmask to represent the selected of the attribute sequence + * @param exprs the attributes in sequence + * @return the attributes of non selected specified via bitmask (with the bit set to 1) + */ + private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression]) + : OpenHashSet[Expression] = { + val set = new OpenHashSet[Expression](2) + + var bit = exprs.length - 1 + while (bit >= 0) { + if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit)) + bit -= 1 + } + + set + } + + /* + * GROUP BY a, b, c, WITH ROLLUP + * is equivalent to + * GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (a), ( )). + * Group Count: N + 1 (N is the number of group expression) + * + * We need to get all of its subsets for the rule described above, the subset is + * represented as the bit masks. + */ + def bitmasks(r: Rollup): Seq[Int] = { + Seq.tabulate(r.groupByExprs.length + 1)(idx => {(1 << idx) - 1}) + } + + /* + * GROUP BY a, b, c, WITH CUBE + * is equivalent to + * GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (b, c), (a, c), (a), (b), (c), ( ) ). + * Group Count: 2^N (N is the number of group expression) + * + * We need to get all of its sub sets for a given GROUPBY expressions, the subset is + * represented as the bit masks. + */ + def bitmasks(c: Cube): Seq[Int] = { + Seq.tabulate(1 << c.groupByExprs.length)(i => i) + } + + /** + * Create an array of Projections for the child projection, and replace the projections' + * expressions which equal GroupBy expressions with Literal(null), if those expressions + * are not set for this grouping set (according to the bit mask). + */ + private[this] def expand(g: GroupingSets): Seq[GroupExpression] = { + val result = new scala.collection.mutable.ArrayBuffer[GroupExpression] + + g.bitmasks.foreach { bitmask => + // get the non selected grouping attributes according to the bit mask + val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs) + + val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown { + case x: Expression if nonSelectedGroupExprSet.contains(x) => + // if the input attribute in the Invalid Grouping Expression set of for this group + // replace it with constant null + Literal(null, expr.dataType) + case x if x == g.gid => + // replace the groupingId with concrete value (the bit mask) + Literal(bitmask, IntegerType) + }) + + result += GroupExpression(substitution) + } + + result.toSeq + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a: Cube if a.resolved => + GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid) + case a: Rollup if a.resolved => + GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid) + case x: GroupingSets if x.resolved => + Aggregate( + x.groupByExprs :+ x.gid, + x.aggregations, + Expand(expand(x), x.child.output :+ x.gid, x.child)) + } + } + /** * Checks for non-aggregated attributes with aggregation */ @@ -183,6 +273,11 @@ class Analyzer(catalog: Catalog, case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressions { + case u @ UnresolvedAttribute(name) + if resolver(name, VirtualColumn.groupingIdName) && + q.isInstanceOf[GroupingAnalytics] => + // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics + q.asInstanceOf[GroupingAnalytics].gid case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = q.resolveChildren(name, resolver).getOrElse(u) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index bc45881e42..ac5b02c2e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -284,6 +284,17 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => +} - +// TODO Semantically we probably not need GroupExpression +// All we need is holding the Seq[Expression], and ONLY used in doing the +// expressions transformation correctly. Probably will be removed since it's +// not like a real expressions. +case class GroupExpression(children: Seq[Expression]) extends Expression { + self: Product => + type EvaluatedType = Seq[Any] + override def eval(input: Row): EvaluatedType = ??? + override def nullable = false + override def foldable = false + override def dataType = ??? } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 7634d392d4..a3c300b5d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -187,3 +187,8 @@ case class AttributeReference( override def toString: String = s"$name#${exprId.id}$typeSuffix" } + +object VirtualColumn { + val groupingIdName = "grouping__id" + def newGroupingId = AttributeReference(groupingIdName, IntegerType, false)() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 64b8d45ebb..a9282b98ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -143,6 +143,89 @@ case class Aggregate( override def output = aggregateExpressions.map(_.toAttribute) } +/** + * Apply the all of the GroupExpressions to every input row, hence we will get + * multiple output rows for a input row. + * @param projections The group of expressions, all of the group expressions should + * output the same schema specified by the parameter `output` + * @param output The output Schema + * @param child Child operator + */ +case class Expand( + projections: Seq[GroupExpression], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode + +trait GroupingAnalytics extends UnaryNode { + self: Product => + def gid: AttributeReference + def groupByExprs: Seq[Expression] + def aggregations: Seq[NamedExpression] + + override def output = aggregations.map(_.toAttribute) +} + +/** + * A GROUP BY clause with GROUPING SETS can generate a result set equivalent + * to generated by a UNION ALL of multiple simple GROUP BY clauses. + * + * We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer + * @param bitmasks A list of bitmasks, each of the bitmask indicates the selected + * GroupBy expressions + * @param groupByExprs The Group By expressions candidates, take effective only if the + * associated bit in the bitmask set to 1. + * @param child Child operator + * @param aggregations The Aggregation expressions, those non selected group by expressions + * will be considered as constant null if it appears in the expressions + * @param gid The attribute represents the virtual column GROUPING__ID, and it's also + * the bitmask indicates the selected GroupBy Expressions for each + * aggregating output row. + * The associated output will be one of the value in `bitmasks` + */ +case class GroupingSets( + bitmasks: Seq[Int], + groupByExprs: Seq[Expression], + child: LogicalPlan, + aggregations: Seq[NamedExpression], + gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + +/** + * Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, + * and eventually will be transformed to Aggregate(.., Expand) in Analyzer + * + * @param groupByExprs The Group By expressions candidates. + * @param child Child operator + * @param aggregations The Aggregation expressions, those non selected group by expressions + * will be considered as constant null if it appears in the expressions + * @param gid The attribute represents the virtual column GROUPING__ID, and it's also + * the bitmask indicates the selected GroupBy Expressions for each + * aggregating output row. + */ +case class Cube( + groupByExprs: Seq[Expression], + child: LogicalPlan, + aggregations: Seq[NamedExpression], + gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + +/** + * Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, + * and eventually will be transformed to Aggregate(.., Expand) in Analyzer + * + * @param groupByExprs The Group By expressions candidates, take effective only if the + * associated bit in the bitmask set to 1. + * @param child Child operator + * @param aggregations The Aggregation expressions, those non selected group by expressions + * will be considered as constant null if it appears in the expressions + * @param gid The attribute represents the virtual column GROUPING__ID, and it's also + * the bitmask indicates the selected GroupBy Expressions for each + * aggregating output row. + */ +case class Rollup( + groupByExprs: Seq[Expression], + child: LogicalPlan, + aggregations: Seq[NamedExpression], + gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output = child.output |