 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala             | 111
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala               |   3
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala    |   2
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala   |  76
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala |   4
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala     |   2
 6 files changed, 146 insertions(+), 52 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index bad115d22f..438cbabdbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -66,9 +66,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
ReorderJoin,
OuterJoinElimination,
PushPredicateThroughJoin,
- PushPredicateThroughProject,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
+ PushDownPredicate,
LimitPushDown,
ColumnPruning,
InferFiltersFromConstraints,
@@ -917,12 +915,13 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
}
/**
- * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]]
- * that were defined in the projection.
+ * Pushes [[Filter]] operators through many operators iff:
+ * 1) the operator is deterministic
+ * 2) the predicate is deterministic and the operator will not change any rows.
*
* This heuristic is valid assuming the expression evaluation cost is minimal.
*/
-object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper {
+object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// SPARK-13473: We can't push the predicate down when the underlying projection outputs non-
// deterministic field(s). Non-deterministic expressions are essentially stateful. This
@@ -939,41 +938,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe
})
project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
- }
-
-}
-
-/**
- * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference
- * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath.
- */
-object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case filter @ Filter(condition, g: Generate) =>
- // Predicates that reference attributes produced by the `Generate` operator cannot
- // be pushed below the operator.
- val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
- cond.references.subsetOf(g.child.outputSet) && cond.deterministic
- }
- if (pushDown.nonEmpty) {
- val pushDownPredicate = pushDown.reduce(And)
- val newGenerate = Generate(g.generator, join = g.join, outer = g.outer,
- g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
- if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate)
- } else {
- filter
- }
- }
-}
-/**
- * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only
- * non-aggregate attributes (typically literals or grouping expressions).
- */
-object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper {
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(condition, aggregate: Aggregate) =>
// Find all the aliased expressions in the aggregate list that don't include any actual
// AggregateExpression, and create a map from the alias to the expression
@@ -999,6 +964,72 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
} else {
filter
}
+
+ case filter @ Filter(condition, child)
+ if child.isInstanceOf[Union] || child.isInstanceOf[Intersect] =>
+ // Union/Intersect could change the rows, so non-deterministic predicates can't be pushed down
+ val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
+ cond.deterministic
+ }
+ if (pushDown.nonEmpty) {
+ val pushDownCond = pushDown.reduceLeft(And)
+ val output = child.output
+ val newGrandChildren = child.children.map { grandchild =>
+ val newCond = pushDownCond transform {
+ case e if output.exists(_.semanticEquals(e)) =>
+ grandchild.output(output.indexWhere(_.semanticEquals(e)))
+ }
+ assert(newCond.references.subsetOf(grandchild.outputSet))
+ Filter(newCond, grandchild)
+ }
+ val newChild = child.withNewChildren(newGrandChildren)
+ if (stayUp.nonEmpty) {
+ Filter(stayUp.reduceLeft(And), newChild)
+ } else {
+ newChild
+ }
+ } else {
+ filter
+ }
+
+ case filter @ Filter(condition, e @ Except(left, _)) =>
+ pushDownPredicate(filter, e.left) { predicate =>
+ e.copy(left = Filter(predicate, left))
+ }
+
+ // Two adjacent filters should be combined by other rules.
+ case filter @ Filter(_, f: Filter) => filter
+ // Should not push predicates through Sample, or it will generate different results.
+ case filter @ Filter(_, s: Sample) => filter
+ // TODO: push predicates through Expand.
+ case filter @ Filter(_, e: Expand) => filter
+
+ case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
+ pushDownPredicate(filter, u.child) { predicate =>
+ u.withNewChildren(Seq(Filter(predicate, u.child)))
+ }
+ }
+
+ private def pushDownPredicate(
+ filter: Filter,
+ grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
+ // Only push down predicates that are deterministic and whose referenced attributes
+ // all come from grandchild.
+ // TODO: non-deterministic predicates could be pushed through some operators that do not
+ // change the rows.
+ val (pushDown, stayUp) = splitConjunctivePredicates(filter.condition).partition { cond =>
+ cond.deterministic && cond.references.subsetOf(grandchild.outputSet)
+ }
+ if (pushDown.nonEmpty) {
+ val newChild = insertFilter(pushDown.reduceLeft(And))
+ if (stayUp.nonEmpty) {
+ Filter(stayUp.reduceLeft(And), newChild)
+ } else {
+ newChild
+ }
+ } else {
+ filter
+ }
}
}
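
The consolidated rule above hinges on one partition: conjuncts that are deterministic and reference only the grandchild's output move below the operator, while the rest stay in a Filter on top. A minimal, self-contained sketch of that split, using a simplified stand-in for Catalyst predicates (the Pred type and all names here are hypothetical, not the Catalyst API):

    object PushDownSketch {
      // Simplified stand-in for a Catalyst predicate: the attribute names it
      // references and whether it is deterministic.
      case class Pred(sql: String, references: Set[String], deterministic: Boolean = true)

      // Mirrors the partition inside PushDownPredicate.pushDownPredicate: a
      // conjunct may move below the operator only if it is deterministic and
      // all of its references come from the grandchild's output.
      def split(conjuncts: Seq[Pred], grandchildOutput: Set[String]): (Seq[Pred], Seq[Pred]) =
        conjuncts.partition { cond =>
          cond.deterministic && cond.references.subsetOf(grandchildOutput)
        }

      def main(args: Array[String]): Unit = {
        val conjuncts = Seq(
          Pred("a = 2", Set("a")),
          Pred("b + rand(10) = 3", Set("b"), deterministic = false))
        val (pushDown, stayUp) = split(conjuncts, Set("a", "b", "c"))
        println(s"pushed below the operator: ${pushDown.map(_.sql)}") // List(a = 2)
        println(s"kept in a Filter above:    ${stayUp.map(_.sql)}")   // List(b + rand(10) = 3)
      }
    }

This is also why the broadcast-hint and union tests below keep the rand(10) conjunct above the rewritten operator: it fails the deterministic check.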
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 6f35d87ebb..0065619135 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -69,6 +69,9 @@ object PhysicalOperation extends PredicateHelper {
val substitutedCondition = substitute(aliases)(condition)
(fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)
+ case BroadcastHint(child) =>
+ collectProjectsAndFilters(child)
+
case other =>
(None, Nil, other, Map.empty)
}
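
The new BroadcastHint case simply recurses, because the hint only annotates the plan for the planner and leaves output and rows untouched, so projects and filters collected below it remain valid. A toy sketch of that idea under a hypothetical miniature plan language (not Catalyst's classes):

    object HintSketch {
      sealed trait Plan
      case class Hint(child: Plan) extends Plan
      case class Project(fields: Seq[String], child: Plan) extends Plan
      case class Relation(name: String) extends Plan

      // Like the new BroadcastHint case in collectProjectsAndFilters: a hint
      // node is transparent, so matching just continues on its child.
      def collectProjects(plan: Plan): (Seq[String], Plan) = plan match {
        case Hint(child)            => collectProjects(child)
        case Project(fields, child) => (fields, collectProjects(child)._2)
        case other                  => (Nil, other)
      }

      def main(args: Array[String]): Unit = {
        val plan = Hint(Project(Seq("a", "b"), Relation("t")))
        println(collectProjects(plan)) // (List(a, b),Relation(t))
      }
    }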
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 2248e03b2f..52b574c0e6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -34,7 +34,7 @@ class ColumnPruningSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Column pruning", FixedPoint(100),
- PushPredicateThroughProject,
+ PushDownPredicate,
ColumnPruning,
CollapseProject) :: Nil
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index b84ae7c5bb..df7529d83f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -33,14 +33,12 @@ class FilterPushdownSuite extends PlanTest {
val batches =
Batch("Subqueries", Once,
EliminateSubqueryAliases) ::
- Batch("Filter Pushdown", Once,
+ Batch("Filter Pushdown", FixedPoint(10),
SamplePushDown,
CombineFilters,
- PushPredicateThroughProject,
+ PushDownPredicate,
BooleanSimplification,
PushPredicateThroughJoin,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
CollapseProject) :: Nil
}
@@ -620,8 +618,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where('a === 3)
+ .select('a, 'b)
.groupBy('a)('a, count('b) as 'c)
.where('c === 2L)
.analyze
@@ -638,8 +636,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where('a + 1 < 3)
+ .select('a, 'b)
.groupBy('a)(('a + 1) as 'aa, count('b) as 'c)
.where('c === 2L || 'aa > 4)
.analyze
@@ -656,8 +654,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where("s" === "s")
+ .select('a, 'b)
.groupBy('a)('a, count('b) as 'c, "s" as 'd)
.where('c === 2L)
.analyze
@@ -681,4 +679,68 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("broadcast hint") {
+ val originalQuery = BroadcastHint(testRelation)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = BroadcastHint(testRelation.where('a === 2L))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("union") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Union(Seq(testRelation, testRelation2))
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Union(Seq(
+ testRelation.where('a === 2L),
+ testRelation2.where('d === 2L)))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("intersect") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Intersect(testRelation, testRelation2)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Intersect(
+ testRelation.where('a === 2L),
+ testRelation2.where('d === 2L))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("except") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Except(testRelation, testRelation2)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Except(
+ testRelation.where('a === 2L),
+ testRelation2)
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
}
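
The union and intersect tests above depend on positional remapping: a pushed conjunct written against the set operator's output ('a === 2) is rewritten for each child by substituting the child's attribute at the same ordinal, which is why the second relation ends up with 'd === 2. A small sketch of that remapping under simplified, hypothetical types (not Catalyst's AttributeReference):

    object UnionRemapSketch {
      case class Attr(name: String)

      // Rewrite a predicate's references for one child of a Union/Intersect:
      // each attribute of the set operator's output is replaced by the child's
      // attribute at the same ordinal position, analogous to the rule's
      // grandchild.output(output.indexWhere(_.semanticEquals(e))).
      def remapForChild(
          predRefs: Seq[Attr],
          setOpOutput: Seq[Attr],
          childOutput: Seq[Attr]): Seq[Attr] =
        predRefs.map(a => childOutput(setOpOutput.indexOf(a)))

      def main(args: Array[String]): Unit = {
        val setOpOutput = Seq(Attr("a"), Attr("b"), Attr("c"))
        val secondChild = Seq(Attr("d"), Attr("e"), Attr("f"))
        // 'a === 2 on the Union becomes 'd === 2 on the second child.
        println(remapForChild(Seq(Attr("a")), setOpOutput, secondChild)) // List(Attr(d))
      }
    }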
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index e2f8146bee..c1ebf8b09e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -36,12 +36,10 @@ class JoinOptimizationSuite extends PlanTest {
EliminateSubqueryAliases) ::
Batch("Filter Pushdown", FixedPoint(100),
CombineFilters,
- PushPredicateThroughProject,
+ PushDownPredicate,
BooleanSimplification,
ReorderJoin,
PushPredicateThroughJoin,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
ColumnPruning,
CollapseProject) :: Nil
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
index 14fb72a8a3..d8cfec5391 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
@@ -34,7 +34,7 @@ class PruneFiltersSuite extends PlanTest {
Batch("Filter Pushdown and Pruning", Once,
CombineFilters,
PruneFilters,
- PushPredicateThroughProject,
+ PushDownPredicate,
PushPredicateThroughJoin) :: Nil
}