diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-04-19 21:53:19 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-04-19 21:53:19 -0700 |
commit | 856bc465d53ccfdfda75c82c85d7f318a5158088 (patch) | |
tree | 4d168f7380313ceaf3bc56249113dfae004150ee | |
parent | 85d759ca3aebb7d60b963207dcada83c75502e52 (diff) | |
download | spark-856bc465d53ccfdfda75c82c85d7f318a5158088.tar.gz spark-856bc465d53ccfdfda75c82c85d7f318a5158088.tar.bz2 spark-856bc465d53ccfdfda75c82c85d7f318a5158088.zip |
[SPARK-14600] [SQL] Push predicates through Expand
## What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-14600
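This PR makes `Expand.output` use different attributes from the grouping attributes produced by the underlying `Project`, because they have different meanings: in `Expand.output` a grouping column can be either the original grouping expression or null. With distinct attributes, we can safely push filters down through `Expand`.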
## How was this patch tested?
Existing tests, plus a new `expand` test case in `FilterPushdownSuite`.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #12496 from cloud-fan/expand.
5 files changed, 34 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 236476900a..8595762988 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -296,10 +296,13 @@ class Analyzer( val nonNullBitmask = x.bitmasks.reduce(_ & _) - val groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) => + val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) => a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0) } + val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child) + val groupingAttrs = expand.output.drop(x.child.output.length) + val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr => // collect all the found AggregateExpression, so we can check an expression is part of // any AggregateExpression or not. 
@@ -321,15 +324,12 @@ class Analyzer( if (index == -1) { e } else { - groupByAttributes(index) + groupingAttrs(index) } }.asInstanceOf[NamedExpression] } - Aggregate( - groupByAttributes :+ gid, - aggregations, - Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child)) + Aggregate(groupingAttrs, aggregations, expand) case f @ Filter(cond, child) if hasGroupingFunction(cond) => val groupingExprs = findGroupingExprs(child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ecc2d773e7..e6d554565d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1020,8 +1020,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { case filter @ Filter(_, f: Filter) => filter // should not push predicates through sample, or will generate different results. case filter @ Filter(_, s: Sample) => filter - // TODO: push predicates through expand - case filter @ Filter(_, e: Expand) => filter case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) => pushDownPredicate(filter, u.child) { predicate => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d4fc9e4da9..a445ce6947 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -516,7 +516,10 @@ private[sql] object Expand { // groupingId is the last output, here we use the bit mask as the concrete value for it. 
} :+ Literal.create(bitmask, IntegerType) } - val output = child.output ++ groupByAttrs :+ gid + + // the `groupByAttrs` has different meaning in `Expand.output`, it could be the original + // grouping expression or null, so here we create new instance of it. + val output = child.output ++ groupByAttrs.map(_.newInstance) :+ gid Expand(projections, output, Project(child.output ++ groupByAliases, child)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index df7529d83f..9174b4e649 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -743,4 +743,19 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("expand") { + val agg = testRelation + .groupBy(Cube(Seq('a, 'b)))('a, 'b, sum('c)) + .analyze + .asInstanceOf[Aggregate] + + val a = agg.output(0) + val b = agg.output(1) + + val query = agg.where(a > 1 && b > 2) + val optimized = Optimize.execute(query) + val correctedAnswer = agg.copy(child = agg.child.where(a > 1 && b > 2)).analyze + comparePlans(optimized, correctedAnswer) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index e54358e657..2d44813f0e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -288,8 +288,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = { assert(a.child == e && e.child == p) - a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && - sameOutput(e.output, p.child.output ++ 
a.groupingExpressions.map(_.asInstanceOf[Attribute])) + a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && sameOutput( + e.output.drop(p.child.output.length), + a.groupingExpressions.map(_.asInstanceOf[Attribute])) } private def groupingSetToSQL( @@ -303,25 +304,28 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val numOriginalOutput = project.child.output.length // Assumption: Aggregate's groupingExpressions is composed of - // 1) the attributes of aliased group by expressions + // 1) the grouping attributes // 2) gid, which is always the last one val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute]) // Assumption: Project's projectList is composed of // 1) the original output (Project's child.output), // 2) the aliased group by expressions. + val expandedAttributes = project.output.drop(numOriginalOutput) val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child) val groupingSQL = groupByExprs.map(_.sql).mkString(", ") // a map from group by attributes to the original group by expressions. val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) + // a map from expanded attributes to the original group by expressions. + val expandedAttrMap = AttributeMap(expandedAttributes.zip(groupByExprs)) val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project => // Assumption: expand.projections is composed of // 1) the original output (Project's child.output), - // 2) group by attributes(or null literal) + // 2) expanded attributes(or null literal) // 3) gid, which is always the last one in each project in Expand project.drop(numOriginalOutput).dropRight(1).collect { - case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr) + case attr: Attribute if expandedAttrMap.contains(attr) => expandedAttrMap(attr) } } val groupingSetSQL = "GROUPING SETS(" + |