aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2015-06-23 10:52:17 -0700
committerMichael Armbrust <michael@databricks.com>2015-06-23 10:52:17 -0700
commit7b1450b666f88452e7fe969a6d59e8b24842ea39 (patch)
tree8ef3d644708111352cec4d936fd8a17bb74fe688
parent4f7fbefb8db56ecaab66bb0ac2ab124416fefe58 (diff)
downloadspark-7b1450b666f88452e7fe969a6d59e8b24842ea39.tar.gz
spark-7b1450b666f88452e7fe969a6d59e8b24842ea39.tar.bz2
spark-7b1450b666f88452e7fe969a6d59e8b24842ea39.zip
[SPARK-7235] [SQL] Refactor the grouping sets
The logical plan `Expand` takes the `output` as constructor argument, which break the references chain. We need to refactor the code, as well as the column pruning. Author: Cheng Hao <hao.cheng@intel.com> Closes #5780 from chenghao-intel/expand and squashes the following commits: 76e4aa4 [Cheng Hao] revert the change for case insenstive 7c10a83 [Cheng Hao] refactor the grouping sets
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala55
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala84
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala4
5 files changed, 78 insertions, 71 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6311784422..0a3f5a7b5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -192,49 +192,17 @@ class Analyzer(
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
}
- /**
- * Create an array of Projections for the child projection, and replace the projections'
- * expressions which equal GroupBy expressions with Literal(null), if those expressions
- * are not set for this grouping set (according to the bit mask).
- */
- private[this] def expand(g: GroupingSets): Seq[Seq[Expression]] = {
- val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
-
- g.bitmasks.foreach { bitmask =>
- // get the non selected grouping attributes according to the bit mask
- val nonSelectedGroupExprs = ArrayBuffer.empty[Expression]
- var bit = g.groupByExprs.length - 1
- while (bit >= 0) {
- if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit)
- bit -= 1
- }
-
- val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
- case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined =>
- // if the input attribute in the Invalid Grouping Expression set of for this group
- // replace it with constant null
- Literal.create(null, expr.dataType)
- case x if x == g.gid =>
- // replace the groupingId with concrete value (the bit mask)
- Literal.create(bitmask, IntegerType)
- })
-
- result += substitution
- }
-
- result.toSeq
- }
-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case a: Cube if a.resolved =>
- GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
- case a: Rollup if a.resolved =>
- GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
- case x: GroupingSets if x.resolved =>
+ case a: Cube =>
+ GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
+ case a: Rollup =>
+ GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
+ case x: GroupingSets =>
+ val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
Aggregate(
- x.groupByExprs :+ x.gid,
+ x.groupByExprs :+ VirtualColumn.groupingIdAttribute,
x.aggregations,
- Expand(expand(x), x.child.output :+ x.gid, x.child))
+ Expand(x.bitmasks, x.groupByExprs, gid, x.child))
}
}
@@ -368,12 +336,7 @@ class Analyzer(
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
- q transformExpressionsUp {
- case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 &&
- resolver(nameParts(0), VirtualColumn.groupingIdName) &&
- q.isInstanceOf[GroupingAnalytics] =>
- // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics
- q.asInstanceOf[GroupingAnalytics].gid
+ q transformExpressionsUp {
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 58dbeaf89c..9cacdceb13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -262,5 +262,5 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E
object VirtualColumn {
val groupingIdName: String = "grouping__id"
- def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)()
+ val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9132a786f7..98b4476076 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -121,6 +121,10 @@ object UnionPushdown extends Rule[LogicalPlan] {
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
+ if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty =>
+ a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references)))
+
// Eliminate attributes that are not needed to calculate the specified aggregates.
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
a.copy(child = Project(a.references.toSeq, child))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 7814e51628..fae339808c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.OpenHashSet
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
@@ -228,24 +229,76 @@ case class Window(
/**
* Apply the all of the GroupExpressions to every input row, hence we will get
* multiple output rows for a input row.
- * @param projections The group of expressions, all of the group expressions should
- * output the same schema specified by the parameter `output`
- * @param output The output Schema
+ * @param bitmasks The bitmask set represents the grouping sets
+ * @param groupByExprs The grouping by expressions
* @param child Child operator
*/
case class Expand(
- projections: Seq[Seq[Expression]],
- output: Seq[Attribute],
+ bitmasks: Seq[Int],
+ groupByExprs: Seq[Expression],
+ gid: Attribute,
child: LogicalPlan) extends UnaryNode {
override def statistics: Statistics = {
val sizeInBytes = child.statistics.sizeInBytes * projections.length
Statistics(sizeInBytes = sizeInBytes)
}
+
+ val projections: Seq[Seq[Expression]] = expand()
+
+ /**
+ * Extract attribute set according to the grouping id
+ * @param bitmask bitmask to represent the selected of the attribute sequence
+ * @param exprs the attributes in sequence
+ * @return the attributes of non selected specified via bitmask (with the bit set to 1)
+ */
+ private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
+ : OpenHashSet[Expression] = {
+ val set = new OpenHashSet[Expression](2)
+
+ var bit = exprs.length - 1
+ while (bit >= 0) {
+ if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
+ bit -= 1
+ }
+
+ set
+ }
+
+ /**
+ * Create an array of Projections for the child projection, and replace the projections'
+ * expressions which equal GroupBy expressions with Literal(null), if those expressions
+ * are not set for this grouping set (according to the bit mask).
+ */
+ private[this] def expand(): Seq[Seq[Expression]] = {
+ val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
+
+ bitmasks.foreach { bitmask =>
+ // get the non selected grouping attributes according to the bit mask
+ val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
+
+ val substitution = (child.output :+ gid).map(expr => expr transformDown {
+ case x: Expression if nonSelectedGroupExprSet.contains(x) =>
+ // if the input attribute in the Invalid Grouping Expression set of for this group
+ // replace it with constant null
+ Literal.create(null, expr.dataType)
+ case x if x == gid =>
+ // replace the groupingId with concrete value (the bit mask)
+ Literal.create(bitmask, IntegerType)
+ })
+
+ result += substitution
+ }
+
+ result.toSeq
+ }
+
+ override def output: Seq[Attribute] = {
+ child.output :+ gid
+ }
}
trait GroupingAnalytics extends UnaryNode {
self: Product =>
- def gid: AttributeReference
def groupByExprs: Seq[Expression]
def aggregations: Seq[NamedExpression]
@@ -266,17 +319,12 @@ trait GroupingAnalytics extends UnaryNode {
* @param child Child operator
* @param aggregations The Aggregation expressions, those non selected group by expressions
* will be considered as constant null if it appears in the expressions
- * @param gid The attribute represents the virtual column GROUPING__ID, and it's also
- * the bitmask indicates the selected GroupBy Expressions for each
- * aggregating output row.
- * The associated output will be one of the value in `bitmasks`
*/
case class GroupingSets(
bitmasks: Seq[Int],
groupByExprs: Seq[Expression],
child: LogicalPlan,
- aggregations: Seq[NamedExpression],
- gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+ aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
@@ -290,15 +338,11 @@ case class GroupingSets(
* @param child Child operator
* @param aggregations The Aggregation expressions, those non selected group by expressions
* will be considered as constant null if it appears in the expressions
- * @param gid The attribute represents the virtual column GROUPING__ID, and it's also
- * the bitmask indicates the selected GroupBy Expressions for each
- * aggregating output row.
*/
case class Cube(
groupByExprs: Seq[Expression],
child: LogicalPlan,
- aggregations: Seq[NamedExpression],
- gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+ aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
@@ -313,15 +357,11 @@ case class Cube(
* @param child Child operator
* @param aggregations The Aggregation expressions, those non selected group by expressions
* will be considered as constant null if it appears in the expressions
- * @param gid The attribute represents the virtual column GROUPING__ID, and it's also
- * the bitmask indicates the selected GroupBy Expressions for each
- * aggregating output row.
*/
case class Rollup(
groupByExprs: Seq[Expression],
child: LogicalPlan,
- aggregations: Seq[NamedExpression],
- gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+ aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5c420eb9d7..1ff1cc224d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -308,8 +308,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Project(projectList, planLater(child)) :: Nil
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
- case logical.Expand(projections, output, child) =>
- execution.Expand(projections, output, planLater(child)) :: Nil
+ case e @ logical.Expand(_, _, _, child) =>
+ execution.Expand(e.projections, e.output, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Window(projectList, windowExpressions, spec, child) =>