aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2014-12-18 18:58:29 -0800
committerMichael Armbrust <michael@databricks.com>2014-12-18 18:58:29 -0800
commitf728e0fe7e860fe6dd3437e248472a67a2d435f8 (patch)
tree28a78bc4c5a9820d9558471882760c0134997c12 /sql/catalyst
parent9804a759b68f56eceb8a2f4ea90f76a92b5f9f67 (diff)
downloadspark-f728e0fe7e860fe6dd3437e248472a67a2d435f8.tar.gz
spark-f728e0fe7e860fe6dd3437e248472a67a2d435f8.tar.bz2
spark-f728e0fe7e860fe6dd3437e248472a67a2d435f8.zip
[SPARK-2663] [SQL] Support the Grouping Set
Add support for `GROUPING SETS`, `ROLLUP`, `CUBE` and the the virtual column `GROUPING__ID`. More details on how to use the `GROUPING SETS" can be found at: https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation,+Cube,+Grouping+and+Rollup https://issues.apache.org/jira/secure/attachment/12676811/grouping_set.pdf The generic idea of the implementations are : 1 Replace the `ROLLUP`, `CUBE` with `GROUPING SETS` 2 Explode each of the input row, and then feed them to `Aggregate` * Each grouping set are represented as the bit mask for the `GroupBy Expression List`, for each bit, `1` means the expression is selected, otherwise `0` (left is the lower bit, and right is the higher bit in the `GroupBy Expression List`) * Several of projections are constructed according to the grouping sets, and within each projection(Seq[Expression), we replace those expressions with `Literal(null)` if it's not selected in the grouping set (based on the bit mask) * Output Schema of `Explode` is `child.output :+ grouping__id` * GroupBy Expressions of `Aggregate` is `GroupBy Expression List :+ grouping__id` * Keep the `Aggregation expressions` the same for the `Aggregate` The expressions substitutions happen in Logic Plan analyzing, so we will benefit from the Logical Plan optimization (e.g. expression constant folding, and map side aggregation etc.), Only an `Explosive` operator added for Physical Plan, which will explode the rows according the pre-set projections. A known issue will be done in the follow up PR: * Optimization `ColumnPruning` is not supported yet for `Explosive` node. Author: Cheng Hao <hao.cheng@intel.com> Closes #1567 from chenghao-intel/grouping_sets and squashes the following commits: fe65fcc [Cheng Hao] Remove the extra space 3547056 [Cheng Hao] Add more doc and Simplify the Expand a7c869d [Cheng Hao] update code as feedbacks d23c672 [Cheng Hao] Add GroupingExpression to replace the Seq[Expression] 414b165 [Cheng Hao] revert the unnecessary changes ec276c6 [Cheng Hao] Support Rollup/Cube/GroupingSets
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala95
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala13
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala83
4 files changed, 195 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3705fcc1f1..1c4088b843 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,11 +17,13 @@
package org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.types.StructType
+import org.apache.spark.sql.catalyst.types.IntegerType
/**
* A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -56,6 +58,7 @@ class Analyzer(catalog: Catalog,
Batch("Resolution", fixedPoint,
ResolveReferences ::
ResolveRelations ::
+ ResolveGroupingAnalytics ::
ResolveSortReferences ::
NewRelationInstances ::
ImplicitGenerate ::
@@ -102,6 +105,93 @@ class Analyzer(catalog: Catalog,
}
}
+ object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
+ /**
+ * Extract attribute set according to the grouping id
+ * @param bitmask bitmask to represent the selected of the attribute sequence
+ * @param exprs the attributes in sequence
+ * @return the attributes of non selected specified via bitmask (with the bit set to 1)
+ */
+ private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
+ : OpenHashSet[Expression] = {
+ val set = new OpenHashSet[Expression](2)
+
+ var bit = exprs.length - 1
+ while (bit >= 0) {
+ if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
+ bit -= 1
+ }
+
+ set
+ }
+
+ /*
+ * GROUP BY a, b, c, WITH ROLLUP
+ * is equivalent to
+ * GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (a), ( )).
+ * Group Count: N + 1 (N is the number of group expression)
+ *
+ * We need to get all of its subsets for the rule described above, the subset is
+ * represented as the bit masks.
+ */
+ def bitmasks(r: Rollup): Seq[Int] = {
+ Seq.tabulate(r.groupByExprs.length + 1)(idx => {(1 << idx) - 1})
+ }
+
+ /*
+ * GROUP BY a, b, c, WITH CUBE
+ * is equivalent to
+ * GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (b, c), (a, c), (a), (b), (c), ( ) ).
+ * Group Count: 2^N (N is the number of group expression)
+ *
+ * We need to get all of its sub sets for a given GROUPBY expressions, the subset is
+ * represented as the bit masks.
+ */
+ def bitmasks(c: Cube): Seq[Int] = {
+ Seq.tabulate(1 << c.groupByExprs.length)(i => i)
+ }
+
+ /**
+ * Create an array of Projections for the child projection, and replace the projections'
+ * expressions which equal GroupBy expressions with Literal(null), if those expressions
+ * are not set for this grouping set (according to the bit mask).
+ */
+ private[this] def expand(g: GroupingSets): Seq[GroupExpression] = {
+ val result = new scala.collection.mutable.ArrayBuffer[GroupExpression]
+
+ g.bitmasks.foreach { bitmask =>
+ // get the non selected grouping attributes according to the bit mask
+ val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs)
+
+ val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
+ case x: Expression if nonSelectedGroupExprSet.contains(x) =>
+ // if the input attribute in the Invalid Grouping Expression set of for this group
+ // replace it with constant null
+ Literal(null, expr.dataType)
+ case x if x == g.gid =>
+ // replace the groupingId with concrete value (the bit mask)
+ Literal(bitmask, IntegerType)
+ })
+
+ result += GroupExpression(substitution)
+ }
+
+ result.toSeq
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case a: Cube if a.resolved =>
+ GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
+ case a: Rollup if a.resolved =>
+ GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
+ case x: GroupingSets if x.resolved =>
+ Aggregate(
+ x.groupByExprs :+ x.gid,
+ x.aggregations,
+ Expand(expand(x), x.child.output :+ x.gid, x.child))
+ }
+ }
+
/**
* Checks for non-aggregated attributes with aggregation
*/
@@ -183,6 +273,11 @@ class Analyzer(catalog: Catalog,
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressions {
+ case u @ UnresolvedAttribute(name)
+ if resolver(name, VirtualColumn.groupingIdName) &&
+ q.isInstanceOf[GroupingAnalytics] =>
+ // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics
+ q.asInstanceOf[GroupingAnalytics].gid
case u @ UnresolvedAttribute(name) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result = q.resolveChildren(name, resolver).getOrElse(u)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index bc45881e42..ac5b02c2e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -284,6 +284,17 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>
+}
-
+// TODO Semantically we probably not need GroupExpression
+// All we need is holding the Seq[Expression], and ONLY used in doing the
+// expressions transformation correctly. Probably will be removed since it's
+// not like a real expressions.
+case class GroupExpression(children: Seq[Expression]) extends Expression {
+ self: Product =>
+ type EvaluatedType = Seq[Any]
+ override def eval(input: Row): EvaluatedType = ???
+ override def nullable = false
+ override def foldable = false
+ override def dataType = ???
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 7634d392d4..a3c300b5d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -187,3 +187,8 @@ case class AttributeReference(
override def toString: String = s"$name#${exprId.id}$typeSuffix"
}
+
+object VirtualColumn {
+ val groupingIdName = "grouping__id"
+ def newGroupingId = AttributeReference(groupingIdName, IntegerType, false)()
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 64b8d45ebb..a9282b98ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -143,6 +143,89 @@ case class Aggregate(
override def output = aggregateExpressions.map(_.toAttribute)
}
+/**
+ * Apply the all of the GroupExpressions to every input row, hence we will get
+ * multiple output rows for a input row.
+ * @param projections The group of expressions, all of the group expressions should
+ * output the same schema specified by the parameter `output`
+ * @param output The output Schema
+ * @param child Child operator
+ */
+case class Expand(
+ projections: Seq[GroupExpression],
+ output: Seq[Attribute],
+ child: LogicalPlan) extends UnaryNode
+
+trait GroupingAnalytics extends UnaryNode {
+ self: Product =>
+ def gid: AttributeReference
+ def groupByExprs: Seq[Expression]
+ def aggregations: Seq[NamedExpression]
+
+ override def output = aggregations.map(_.toAttribute)
+}
+
+/**
+ * A GROUP BY clause with GROUPING SETS can generate a result set equivalent
+ * to generated by a UNION ALL of multiple simple GROUP BY clauses.
+ *
+ * We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer
+ * @param bitmasks A list of bitmasks, each of the bitmask indicates the selected
+ * GroupBy expressions
+ * @param groupByExprs The Group By expressions candidates, take effective only if the
+ * associated bit in the bitmask set to 1.
+ * @param child Child operator
+ * @param aggregations The Aggregation expressions, those non selected group by expressions
+ * will be considered as constant null if it appears in the expressions
+ * @param gid The attribute represents the virtual column GROUPING__ID, and it's also
+ * the bitmask indicates the selected GroupBy Expressions for each
+ * aggregating output row.
+ * The associated output will be one of the value in `bitmasks`
+ */
+case class GroupingSets(
+ bitmasks: Seq[Int],
+ groupByExprs: Seq[Expression],
+ child: LogicalPlan,
+ aggregations: Seq[NamedExpression],
+ gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+
+/**
+ * Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
+ * and eventually will be transformed to Aggregate(.., Expand) in Analyzer
+ *
+ * @param groupByExprs The Group By expressions candidates.
+ * @param child Child operator
+ * @param aggregations The Aggregation expressions, those non selected group by expressions
+ * will be considered as constant null if it appears in the expressions
+ * @param gid The attribute represents the virtual column GROUPING__ID, and it's also
+ * the bitmask indicates the selected GroupBy Expressions for each
+ * aggregating output row.
+ */
+case class Cube(
+ groupByExprs: Seq[Expression],
+ child: LogicalPlan,
+ aggregations: Seq[NamedExpression],
+ gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+
+/**
+ * Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
+ * and eventually will be transformed to Aggregate(.., Expand) in Analyzer
+ *
+ * @param groupByExprs The Group By expressions candidates, take effective only if the
+ * associated bit in the bitmask set to 1.
+ * @param child Child operator
+ * @param aggregations The Aggregation expressions, those non selected group by expressions
+ * will be considered as constant null if it appears in the expressions
+ * @param gid The attribute represents the virtual column GROUPING__ID, and it's also
+ * the bitmask indicates the selected GroupBy Expressions for each
+ * aggregating output row.
+ */
+case class Rollup(
+ groupByExprs: Seq[Expression],
+ child: LogicalPlan,
+ aggregations: Seq[NamedExpression],
+ gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output = child.output