aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-11-26 00:19:42 -0800
committerDavies Liu <davies.liu@gmail.com>2015-11-26 00:19:42 -0800
commit27d69a0573ed55e916a464e268dcfd5ecc6ed849 (patch)
tree764ada6aeed457028d35c8f591b9d7b6e61c5969 /sql/catalyst
parentd3ef693325f91a1ed340c9756c81244a80398eb2 (diff)
downloadspark-27d69a0573ed55e916a464e268dcfd5ecc6ed849.tar.gz
spark-27d69a0573ed55e916a464e268dcfd5ecc6ed849.tar.bz2
spark-27d69a0573ed55e916a464e268dcfd5ecc6ed849.zip
[SPARK-11973] [SQL] push filter through aggregation with alias and literals
Currently, a filter can't be pushed through an aggregation with aliases or literals; this patch fixes that. After this patch, the time of TPC-DS query 4 goes down to 13 seconds from 141 seconds (a 10x improvement). cc nongli yhuai Author: Davies Liu <davies@databricks.com> Closes #9959 from davies/push_filter2.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala9
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala28
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala53
3 files changed, 79 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 68557479a9..304b438c84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -65,6 +65,15 @@ trait PredicateHelper {
}
}
+ // Substitute any known alias from a map.
+ protected def replaceAlias(
+ condition: Expression,
+ aliases: AttributeMap[Expression]): Expression = {
+ condition.transform {
+ case a: Attribute => aliases.getOrElse(a, a)
+ }
+ }
+
/**
* Returns true if `expr` can be evaluated using only the output of `plan`. This method
* can be used to determine when it is acceptable to move expression evaluation within a query
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f4dba67f13..52f609bc15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -640,20 +640,14 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe
filter
} else {
// Push down the small conditions without nondeterministic expressions.
- val pushedCondition = deterministic.map(replaceAlias(_, aliasMap)).reduce(And)
+ val pushedCondition =
+ deterministic.map(replaceAlias(_, aliasMap)).reduce(And)
Filter(nondeterministic.reduce(And),
project.copy(child = Filter(pushedCondition, grandChild)))
}
}
}
- // Substitute any attributes that are produced by the child projection, so that we safely
- // eliminate it.
- private def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = {
- condition.transform {
- case a: Attribute => sourceAliases.getOrElse(a, a)
- }
- }
}
/**
@@ -690,12 +684,24 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(condition,
aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) =>
- val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
- conjunct => conjunct.references subsetOf AttributeSet(groupingExpressions)
+
+ def hasAggregate(expression: Expression): Boolean = expression match {
+ case agg: AggregateExpression => true
+ case other => expression.children.exists(hasAggregate)
+ }
+ // Create a map of Alias for expressions that does not have AggregateExpression
+ val aliasMap = AttributeMap(aggregateExpressions.collect {
+ case a: Alias if !hasAggregate(a.child) => (a.toAttribute, a.child)
+ })
+
+ val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { conjunct =>
+ val replaced = replaceAlias(conjunct, aliasMap)
+ replaced.references.subsetOf(grandChild.outputSet) && replaced.deterministic
}
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
- val withPushdown = aggregate.copy(child = Filter(pushDownPredicate, grandChild))
+ val replaced = replaceAlias(pushDownPredicate, aliasMap)
+ val withPushdown = aggregate.copy(child = Filter(replaced, grandChild))
stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
} else {
filter
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 0290fafe87..0128c220ba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -697,4 +697,57 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("aggregate: push down filters with alias") {
+ val originalQuery = testRelation
+ .select('a, 'b)
+ .groupBy('a)(('a + 1) as 'aa, count('b) as 'c)
+ .where(('c === 2L || 'aa > 4) && 'aa < 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = testRelation
+ .select('a, 'b)
+ .where('a + 1 < 3)
+ .groupBy('a)(('a + 1) as 'aa, count('b) as 'c)
+ .where('c === 2L || 'aa > 4)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("aggregate: push down filters with literal") {
+ val originalQuery = testRelation
+ .select('a, 'b)
+ .groupBy('a)('a, count('b) as 'c, "s" as 'd)
+ .where('c === 2L && 'd === "s")
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = testRelation
+ .select('a, 'b)
+ .where("s" === "s")
+ .groupBy('a)('a, count('b) as 'c, "s" as 'd)
+ .where('c === 2L)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("aggregate: don't push down filters which is nondeterministic") {
+ val originalQuery = testRelation
+ .select('a, 'b)
+ .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd"))
+ .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = testRelation
+ .select('a, 'b)
+ .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd"))
+ .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
}