aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala 23
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala 6
-rw-r--r-- sql/core/src/test/resources/sql-tests/inputs/limit.sql 3
-rw-r--r-- sql/core/src/test/resources/sql-tests/results/limit.sql.out 10
4 files changed, 35 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f7aa6da0a5..ce57f05868 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1208,17 +1208,28 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
filter
}
- // two filters should be combine together by other rules
- case filter @ Filter(_, _: Filter) => filter
- // should not push predicates through sample, or will generate different results.
- case filter @ Filter(_, _: Sample) => filter
-
- case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
+ case filter @ Filter(condition, u: UnaryNode)
+ if canPushThrough(u) && u.expressions.forall(_.deterministic) =>
pushDownPredicate(filter, u.child) { predicate =>
u.withNewChildren(Seq(Filter(predicate, u.child)))
}
}
+ private def canPushThrough(p: UnaryNode): Boolean = p match {
+ // Whitelist for the generic unary case: unlisted operators (e.g. limit, sample) return false.
+ // Project, aggregate and union are not listed because earlier cases in this rule handle them.
+ case _: AppendColumns => true
+ case _: BroadcastHint => true
+ case _: Distinct => true
+ case _: Generate => true
+ case _: Pivot => true
+ case _: RedistributeData => true
+ case _: Repartition => true
+ case _: ScriptTransformation => true
+ case _: Sort => true
+ case _ => false
+ }
+
private def pushDownPredicate(
filter: Filter,
grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 596b8fcea1..9f25e9d8e9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -111,6 +111,12 @@ class FilterPushdownSuite extends PlanTest {
assert(optimized == correctAnswer)
}
+ test("SPARK-16994: filter should not be pushed through limit") {
+ val originalQuery = testRelation.limit(10).where('a === 1).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, originalQuery)
+ }
+
test("can't push without rewrite") {
val originalQuery =
testRelation
diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql
index 892a1bb4b5..2ea35f7f3a 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql
@@ -18,3 +18,6 @@ select * from testdata limit key > 3;
-- limit must be integer
select * from testdata limit true;
select * from testdata limit 'a';
+
+-- limit within a subquery
+select * from (select * from range(10) limit 5) where id > 3;
diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out
index b71b058869..cb4e4d0481 100644
--- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 9
+-- Number of queries: 10
-- !query 0
@@ -81,3 +81,11 @@ struct<>
-- !query 8 output
org.apache.spark.sql.AnalysisException
The limit expression must be integer type, but got string;
+
+
+-- !query 9
+select * from (select * from range(10) limit 5) where id > 3
+-- !query 9 schema
+struct<id:bigint>
+-- !query 9 output
+4