aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala18
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala16
2 files changed, 34 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index bfd24287c9..7d41ef9aaf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -41,6 +41,7 @@ object DefaultOptimizer extends Optimizer {
Batch("Operator Optimizations", FixedPoint(100),
// Operator push down
UnionPushDown,
+ SamplePushDown,
PushPredicateThroughJoin,
PushPredicateThroughProject,
PushPredicateThroughGenerate,
@@ -66,6 +67,23 @@ object DefaultOptimizer extends Optimizer {
}
/**
+ * Pushes a projection beneath a Sample.
+ *
+ * Filters are deliberately NOT pushed down: sampling the filtered rows is not guaranteed
+ * to keep the same rows as filtering the sampled rows, so that rewrite could change the
+ * result of the query.
+ */
+object SamplePushDown extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    // Push down projection into sample
+    case Project(projectList, Sample(lb, up, replace, seed, child)) =>
+      Sample(lb, up, replace, seed,
+        Project(projectList, child))
+  }
+}
+
+/**
* Pushes operations to either side of a Union.
*/
object UnionPushDown extends Rule[LogicalPlan] {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index ffdc673cdc..dc28b3ffb5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -34,6 +34,7 @@ class FilterPushdownSuite extends PlanTest {
Batch("Subqueries", Once,
EliminateSubQueries) ::
Batch("Filter Pushdown", Once,
+ SamplePushDown,
CombineFilters,
PushPredicateThroughProject,
BooleanSimplification,
@@ -593,4 +594,19 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized1, analysis.EliminateSubQueries(correctAnswer1))
}
+
+  test("push project down into sample") {
+    // Only projection is covered: a select above a Sample should be moved beneath it.
+    val x = testRelation.subquery('x)
+    val originalQuery =
+      Sample(0.0, 0.6, false, 11L, x).select('a)
+
+    val originalQueryAnalyzed =
+      EliminateSubQueries(analysis.SimpleAnalyzer.execute(originalQuery))
+    val optimized = Optimize.execute(originalQueryAnalyzed)
+
+    val correctAnswer =
+      Sample(0.0, 0.6, false, 11L, x.select('a))
+    comparePlans(optimized, correctAnswer.analyze)
+  }
}