diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-04-19 21:53:19 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-04-19 21:53:19 -0700 |
commit | 856bc465d53ccfdfda75c82c85d7f318a5158088 (patch) | |
tree | 4d168f7380313ceaf3bc56249113dfae004150ee /sql/catalyst | |
parent | 85d759ca3aebb7d60b963207dcada83c75502e52 (diff) | |
download | spark-856bc465d53ccfdfda75c82c85d7f318a5158088.tar.gz spark-856bc465d53ccfdfda75c82c85d7f318a5158088.tar.bz2 spark-856bc465d53ccfdfda75c82c85d7f318a5158088.zip |
[SPARK-14600] [SQL] Push predicates through Expand
## What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-14600
This PR makes `Expand.output` have different attributes from the grouping attributes produced by the underlying `Project`, as they have different meanings, so that we can safely push down filters through `Expand`.
## How was this patch tested?
Existing tests.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #12496 from cloud-fan/expand.
Diffstat (limited to 'sql/catalyst')
4 files changed, 25 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 236476900a..8595762988 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -296,10 +296,13 @@ class Analyzer( val nonNullBitmask = x.bitmasks.reduce(_ & _) - val groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) => + val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) => a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0) } + val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child) + val groupingAttrs = expand.output.drop(x.child.output.length) + val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr => // collect all the found AggregateExpression, so we can check an expression is part of // any AggregateExpression or not. 
@@ -321,15 +324,12 @@ class Analyzer( if (index == -1) { e } else { - groupByAttributes(index) + groupingAttrs(index) } }.asInstanceOf[NamedExpression] } - Aggregate( - groupByAttributes :+ gid, - aggregations, - Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child)) + Aggregate(groupingAttrs, aggregations, expand) case f @ Filter(cond, child) if hasGroupingFunction(cond) => val groupingExprs = findGroupingExprs(child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ecc2d773e7..e6d554565d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1020,8 +1020,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { case filter @ Filter(_, f: Filter) => filter // should not push predicates through sample, or will generate different results. case filter @ Filter(_, s: Sample) => filter - // TODO: push predicates through expand - case filter @ Filter(_, e: Expand) => filter case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) => pushDownPredicate(filter, u.child) { predicate => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d4fc9e4da9..a445ce6947 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -516,7 +516,10 @@ private[sql] object Expand { // groupingId is the last output, here we use the bit mask as the concrete value for it. 
} :+ Literal.create(bitmask, IntegerType) } - val output = child.output ++ groupByAttrs :+ gid + + // the `groupByAttrs` has different meaning in `Expand.output`, it could be the original + // grouping expression or null, so here we create new instance of it. + val output = child.output ++ groupByAttrs.map(_.newInstance) :+ gid Expand(projections, output, Project(child.output ++ groupByAliases, child)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index df7529d83f..9174b4e649 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -743,4 +743,19 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("expand") { + val agg = testRelation + .groupBy(Cube(Seq('a, 'b)))('a, 'b, sum('c)) + .analyze + .asInstanceOf[Aggregate] + + val a = agg.output(0) + val b = agg.output(1) + + val query = agg.where(a > 1 && b > 2) + val optimized = Optimize.execute(query) + val correctedAnswer = agg.copy(child = agg.child.where(a > 1 && b > 2)).analyze + comparePlans(optimized, correctedAnswer) + } } |