aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-04-13 13:01:13 -0700
committerDavies Liu <davies.liu@gmail.com>2016-04-13 13:01:13 -0700
commitdbbe149070052af5cda04f7b110d65de73766ded (patch)
tree51ecbe9b35cddfd26abd6c61d67d80f5a8e957f3
parentf9d578eaa107d8e8503c1563a2b3990c85104298 (diff)
downloadspark-dbbe149070052af5cda04f7b110d65de73766ded.tar.gz
spark-dbbe149070052af5cda04f7b110d65de73766ded.tar.bz2
spark-dbbe149070052af5cda04f7b110d65de73766ded.zip
[SPARK-14581] [SQL] push predicatese through more logical plans
## What changes were proposed in this pull request? Right now, filter push down only works with Project, Aggregate, Generate and Join, they can't be pushed through many other plans. This PR added support for Union, Intersect, Except and all unary plans. ## How was this patch tested? Added tests. Author: Davies Liu <davies@databricks.com> Closes #12342 from davies/filter_hint.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala111
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala3
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala76
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala2
6 files changed, 146 insertions, 52 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index bad115d22f..438cbabdbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -66,9 +66,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
ReorderJoin,
OuterJoinElimination,
PushPredicateThroughJoin,
- PushPredicateThroughProject,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
+ PushDownPredicate,
LimitPushDown,
ColumnPruning,
InferFiltersFromConstraints,
@@ -917,12 +915,13 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
}
/**
- * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]]
- * that were defined in the projection.
+ * Pushes [[Filter]] operators through many operators iff:
+ * 1) the operator is deterministic
+ * 2) the predicate is deterministic and the operator will not change any of rows.
*
* This heuristic is valid assuming the expression evaluation cost is minimal.
*/
-object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper {
+object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// SPARK-13473: We can't push the predicate down when the underlying projection output non-
// deterministic field(s). Non-deterministic expressions are essentially stateful. This
@@ -939,41 +938,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe
})
project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
- }
-
-}
-
-/**
- * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference
- * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath.
- */
-object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case filter @ Filter(condition, g: Generate) =>
- // Predicates that reference attributes produced by the `Generate` operator cannot
- // be pushed below the operator.
- val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
- cond.references.subsetOf(g.child.outputSet) && cond.deterministic
- }
- if (pushDown.nonEmpty) {
- val pushDownPredicate = pushDown.reduce(And)
- val newGenerate = Generate(g.generator, join = g.join, outer = g.outer,
- g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
- if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate)
- } else {
- filter
- }
- }
-}
-/**
- * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only
- * non-aggregate attributes (typically literals or grouping expressions).
- */
-object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper {
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(condition, aggregate: Aggregate) =>
// Find all the aliased expressions in the aggregate list that don't include any actual
// AggregateExpression, and create a map from the alias to the expression
@@ -999,6 +964,72 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
} else {
filter
}
+
+ case filter @ Filter(condition, child)
+ if child.isInstanceOf[Union] || child.isInstanceOf[Intersect] =>
+ // Union/Intersect could change the rows, so non-deterministic predicate can't be pushed down
+ val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
+ cond.deterministic
+ }
+ if (pushDown.nonEmpty) {
+ val pushDownCond = pushDown.reduceLeft(And)
+ val output = child.output
+ val newGrandChildren = child.children.map { grandchild =>
+ val newCond = pushDownCond transform {
+ case e if output.exists(_.semanticEquals(e)) =>
+ grandchild.output(output.indexWhere(_.semanticEquals(e)))
+ }
+ assert(newCond.references.subsetOf(grandchild.outputSet))
+ Filter(newCond, grandchild)
+ }
+ val newChild = child.withNewChildren(newGrandChildren)
+ if (stayUp.nonEmpty) {
+ Filter(stayUp.reduceLeft(And), newChild)
+ } else {
+ newChild
+ }
+ } else {
+ filter
+ }
+
+ case filter @ Filter(condition, e @ Except(left, _)) =>
+ pushDownPredicate(filter, e.left) { predicate =>
+ e.copy(left = Filter(predicate, left))
+ }
+
+ // two filters should be combine together by other rules
+ case filter @ Filter(_, f: Filter) => filter
+ // should not push predicates through sample, or will generate different results.
+ case filter @ Filter(_, s: Sample) => filter
+ // TODO: push predicates through expand
+ case filter @ Filter(_, e: Expand) => filter
+
+ case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
+ pushDownPredicate(filter, u.child) { predicate =>
+ u.withNewChildren(Seq(Filter(predicate, u.child)))
+ }
+ }
+
+ private def pushDownPredicate(
+ filter: Filter,
+ grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
+ // Only push down the predicates that is deterministic and all the referenced attributes
+ // come from grandchild.
+ // TODO: non-deterministic predicates could be pushed through some operators that do not change
+ // the rows.
+ val (pushDown, stayUp) = splitConjunctivePredicates(filter.condition).partition { cond =>
+ cond.deterministic && cond.references.subsetOf(grandchild.outputSet)
+ }
+ if (pushDown.nonEmpty) {
+ val newChild = insertFilter(pushDown.reduceLeft(And))
+ if (stayUp.nonEmpty) {
+ Filter(stayUp.reduceLeft(And), newChild)
+ } else {
+ newChild
+ }
+ } else {
+ filter
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 6f35d87ebb..0065619135 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -69,6 +69,9 @@ object PhysicalOperation extends PredicateHelper {
val substitutedCondition = substitute(aliases)(condition)
(fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)
+ case BroadcastHint(child) =>
+ collectProjectsAndFilters(child)
+
case other =>
(None, Nil, other, Map.empty)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 2248e03b2f..52b574c0e6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -34,7 +34,7 @@ class ColumnPruningSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Column pruning", FixedPoint(100),
- PushPredicateThroughProject,
+ PushDownPredicate,
ColumnPruning,
CollapseProject) :: Nil
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index b84ae7c5bb..df7529d83f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -33,14 +33,12 @@ class FilterPushdownSuite extends PlanTest {
val batches =
Batch("Subqueries", Once,
EliminateSubqueryAliases) ::
- Batch("Filter Pushdown", Once,
+ Batch("Filter Pushdown", FixedPoint(10),
SamplePushDown,
CombineFilters,
- PushPredicateThroughProject,
+ PushDownPredicate,
BooleanSimplification,
PushPredicateThroughJoin,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
CollapseProject) :: Nil
}
@@ -620,8 +618,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where('a === 3)
+ .select('a, 'b)
.groupBy('a)('a, count('b) as 'c)
.where('c === 2L)
.analyze
@@ -638,8 +636,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where('a + 1 < 3)
+ .select('a, 'b)
.groupBy('a)(('a + 1) as 'aa, count('b) as 'c)
.where('c === 2L || 'aa > 4)
.analyze
@@ -656,8 +654,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where("s" === "s")
+ .select('a, 'b)
.groupBy('a)('a, count('b) as 'c, "s" as 'd)
.where('c === 2L)
.analyze
@@ -681,4 +679,68 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("broadcast hint") {
+ val originalQuery = BroadcastHint(testRelation)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = BroadcastHint(testRelation.where('a === 2L))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("union") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Union(Seq(testRelation, testRelation2))
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Union(Seq(
+ testRelation.where('a === 2L),
+ testRelation2.where('d === 2L)))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("intersect") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Intersect(testRelation, testRelation2)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Intersect(
+ testRelation.where('a === 2L),
+ testRelation2.where('d === 2L))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("except") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Except(testRelation, testRelation2)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Except(
+ testRelation.where('a === 2L),
+ testRelation2)
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index e2f8146bee..c1ebf8b09e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -36,12 +36,10 @@ class JoinOptimizationSuite extends PlanTest {
EliminateSubqueryAliases) ::
Batch("Filter Pushdown", FixedPoint(100),
CombineFilters,
- PushPredicateThroughProject,
+ PushDownPredicate,
BooleanSimplification,
ReorderJoin,
PushPredicateThroughJoin,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
ColumnPruning,
CollapseProject) :: Nil
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
index 14fb72a8a3..d8cfec5391 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
@@ -34,7 +34,7 @@ class PruneFiltersSuite extends PlanTest {
Batch("Filter Pushdown and Pruning", Once,
CombineFilters,
PruneFilters,
- PushPredicateThroughProject,
+ PushDownPredicate,
PushPredicateThroughJoin) :: Nil
}