diff options
2 files changed, 34 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bfd24287c9..7d41ef9aaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -41,6 +41,7 @@ object DefaultOptimizer extends Optimizer { Batch("Operator Optimizations", FixedPoint(100), // Operator push down UnionPushDown, + SamplePushDown, PushPredicateThroughJoin, PushPredicateThroughProject, PushPredicateThroughGenerate, @@ -66,6 +67,23 @@ object DefaultOptimizer extends Optimizer { } /** + * Pushes operations down into a Sample. + */ +object SamplePushDown extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Push down filter into sample + case Filter(condition, s @ Sample(lb, up, replace, seed, child)) => + Sample(lb, up, replace, seed, + Filter(condition, child)) + // Push down projection into sample + case Project(projectList, s @ Sample(lb, up, replace, seed, child)) => + Sample(lb, up, replace, seed, + Project(projectList, child)) + } +} + +/** * Pushes operations to either side of a Union. */ object UnionPushDown extends Rule[LogicalPlan] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ffdc673cdc..dc28b3ffb5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -34,6 +34,7 @@ class FilterPushdownSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Filter Pushdown", Once, + SamplePushDown, CombineFilters, PushPredicateThroughProject, BooleanSimplification, @@ -593,4 +594,19 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized1, analysis.EliminateSubQueries(correctAnswer1)) } + + test("push project and filter down into sample") { + val x = testRelation.subquery('x) + val originalQuery = + Sample(0.0, 0.6, false, 11L, x).select('a) + + val originalQueryAnalyzed = EliminateSubQueries(analysis.SimpleAnalyzer.execute(originalQuery)) + + val optimized = Optimize.execute(originalQueryAnalyzed) + + val correctAnswer = + Sample(0.0, 0.6, false, 11L, x.select('a)) + + comparePlans(optimized, correctAnswer.analyze) + } } |