author    Wenchen Fan <wenchen@databricks.com>  2016-04-19 21:53:19 -0700
committer Davies Liu <davies.liu@gmail.com>     2016-04-19 21:53:19 -0700
commit    856bc465d53ccfdfda75c82c85d7f318a5158088 (patch)
tree      4d168f7380313ceaf3bc56249113dfae004150ee /sql/catalyst
parent    85d759ca3aebb7d60b963207dcada83c75502e52 (diff)
[SPARK-14600] [SQL] Push predicates through Expand
## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-14600

This PR makes `Expand.output` use attributes distinct from the grouping attributes produced by the underlying `Project`, since the two have different meanings: below `Expand`, an attribute always holds the original column value, while in `Expand.output` it may be nulled out per grouping set. With the identities separated, we can safely push filters down through `Expand`.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12496 from cloud-fan/expand.
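To see why sharing attributes between `Expand` and the `Project` below it was a problem, here is a self-contained sketch of what `Expand` does for `CUBE(a, b)`. This is plain Scala modeling the semantics, not Spark API; the names and the four grouping sets are illustrative.

```scala
// Each input row is replicated once per grouping set; the columns "grouped
// away" in that set are nulled out (None here), and a grouping id is appended.
case class In(a: Int, b: Int, c: Int)
case class Out(a: Option[Int], b: Option[Int], c: Int, gid: Int)

def expand(rows: Seq[In]): Seq[Out] = rows.flatMap { r =>
  Seq(
    Out(Some(r.a), Some(r.b), r.c, 0), // grouping set (a, b)
    Out(Some(r.a), None,      r.c, 1), // grouping set (a)
    Out(None,      Some(r.b), r.c, 2), // grouping set (b)
    Out(None,      None,      r.c, 3)  // grouping set ()
  )
}
```

The output column `a` is therefore not the same thing as the input column `a`: for `gid` 2 and 3 it is null even when the input value satisfies a predicate like `a > 1`. Reusing one attribute for both roles is what made naive filter pushdown through `Expand` unsafe, and why the optimizer previously blocked it outright.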
Diffstat (limited to 'sql/catalyst')
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala             | 12
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala           |  2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala  |  5
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala | 15
 4 files changed, 25 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 236476900a..8595762988 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -296,10 +296,13 @@ class Analyzer(
val nonNullBitmask = x.bitmasks.reduce(_ & _)
- val groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
+ val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0)
}
+ val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child)
+ val groupingAttrs = expand.output.drop(x.child.output.length)
+
val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
// collect all the found AggregateExpression, so we can check an expression is part of
// any AggregateExpression or not.
@@ -321,15 +324,12 @@ class Analyzer(
if (index == -1) {
e
} else {
- groupByAttributes(index)
+ groupingAttrs(index)
}
}.asInstanceOf[NamedExpression]
}
- Aggregate(
- groupByAttributes :+ gid,
- aggregations,
- Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
+ Aggregate(groupingAttrs, aggregations, expand)
case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
val groupingExprs = findGroupingExprs(child)
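In the hunk above, the analyzer now constructs the `Expand` node first and reads the grouping attributes back out of `expand.output`, so the `Aggregate` refers to exactly the attributes `Expand` produces. Since `Expand.output` is `child.output ++ groupingAttrs :+ gid`, dropping the child's columns leaves the grouping attributes plus the grouping id. A toy illustration of that indexing, using plain strings in place of Catalyst attributes (the column names are made up):

```scala
val childOutput  = Seq("a", "b", "c")                    // x.child.output
val expandOutput = Seq("a", "b", "c", "a'", "b'", "gid") // pass-through ++ fresh attrs :+ gid

// The drop in the new code: everything after the child's columns.
val groupingAttrs = expandOutput.drop(childOutput.length)
assert(groupingAttrs == Seq("a'", "b'", "gid"))          // what the Aggregate groups on
```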
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ecc2d773e7..e6d554565d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1020,8 +1020,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
case filter @ Filter(_, f: Filter) => filter
// should not push predicates through sample, or will generate different results.
case filter @ Filter(_, s: Sample) => filter
- // TODO: push predicates through expand
- case filter @ Filter(_, e: Expand) => filter
case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
pushDownPredicate(filter, u.child) { predicate =>
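With the special case removed, `Filter` over `Expand` falls through to the generic `UnaryNode` rule above, which pushes down only the conjuncts whose references are all produced by the child. Because the grouping attributes in `Expand.output` now have fresh identities, a predicate on them no longer appears evaluable below `Expand`. A minimal model of that split in plain Scala (stand-in types, not the Catalyst classes):

```scala
// A predicate reduced to the set of attribute names it references.
case class Pred(references: Set[String])

// Push below the operator only the conjuncts evaluable on the child's output.
def split(conjuncts: Seq[Pred], childOutput: Set[String]): (Seq[Pred], Seq[Pred]) =
  conjuncts.partition(_.references.subsetOf(childOutput))

// Expand's child produces {a, b, c}; the filter mentions the fresh grouping
// attribute a' and the pass-through column c.
val (pushed, kept) = split(Seq(Pred(Set("a'")), Pred(Set("c"))), Set("a", "b", "c"))
assert(pushed == Seq(Pred(Set("c"))))  // safe to evaluate below Expand
assert(kept   == Seq(Pred(Set("a'")))) // must stay above Expand
```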
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d4fc9e4da9..a445ce6947 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -516,7 +516,10 @@ private[sql] object Expand {
// groupingId is the last output, here we use the bit mask as the concrete value for it.
} :+ Literal.create(bitmask, IntegerType)
}
- val output = child.output ++ groupByAttrs :+ gid
+
+ // `groupByAttrs` has a different meaning in `Expand.output`: there, each attribute
+ // may hold either the original grouping value or null, so we create new instances of them.
+ val output = child.output ++ groupByAttrs.map(_.newInstance) :+ gid
Expand(projections, output, Project(child.output ++ groupByAliases, child))
}
}
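`newInstance` is what actually separates the identities: it copies an attribute with a fresh expression id, and Catalyst attributes compare by id rather than by name. A toy model of that behavior (a stand-in class, not Catalyst's `AttributeReference`, whose real `newInstance()` allocates the id internally):

```scala
// Attributes are equal only when their ids match, even if their names do.
case class Attr(name: String, exprId: Long) {
  def newInstance(freshId: Long): Attr = copy(exprId = freshId)
}

val fromProject    = Attr("a", exprId = 1)       // the alias produced by the Project
val inExpandOutput = fromProject.newInstance(2)  // same name, new identity

assert(inExpandOutput.name == fromProject.name)
assert(inExpandOutput != fromProject) // a filter bound to one does not match the other
```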
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index df7529d83f..9174b4e649 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -743,4 +743,19 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("expand") {
+ val agg = testRelation
+ .groupBy(Cube(Seq('a, 'b)))('a, 'b, sum('c))
+ .analyze
+ .asInstanceOf[Aggregate]
+
+ val a = agg.output(0)
+ val b = agg.output(1)
+
+ val query = agg.where(a > 1 && b > 2)
+ val optimized = Optimize.execute(query)
+ val correctedAnswer = agg.copy(child = agg.child.where(a > 1 && b > 2)).analyze
+ comparePlans(optimized, correctedAnswer)
+ }
}
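For orientation, the plans this test compares look roughly as follows (a hand-drawn sketch with attribute ids and the exact projection lists elided). The filter on the grouping columns `a` and `b` is pushed through the `Aggregate` but correctly stops above the `Expand`, because those columns may be nulled out per grouping set:

```
Filter (a > 1 && b > 2)                 Aggregate [a, b, gid] [a, b, sum(c)]
  Aggregate [a, b, gid] [a, b, sum(c)]    Filter (a > 1 && b > 2)
    Expand [...]                    ==>     Expand [...]
      Project [...]                           Project [...]
        testRelation                            testRelation
```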