aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala8
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala21
2 files changed, 24 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 63669e970b..709f7d672d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -571,7 +571,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
joinType match {
- case Inner =>
+ case _ @ (Inner | LeftSemi) =>
// push down the single side only join filter for both sides sub queries
val newLeft = leftJoinConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
@@ -579,7 +579,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = commonJoinCondition.reduceLeftOption(And)
- Join(newLeft, newRight, Inner, newJoinCond)
+ Join(newLeft, newRight, joinType, newJoinCond)
case RightOuter =>
// push down the left side only join filter for left side sub query
val newLeft = leftJoinConditions.
@@ -588,14 +588,14 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
Join(newLeft, newRight, RightOuter, newJoinCond)
- case _ @ (LeftOuter | LeftSemi) =>
+ case LeftOuter =>
// push down the right side only join filter for right sub query
val newLeft = left
val newRight = rightJoinConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
- Join(newLeft, newRight, joinType, newJoinCond)
+ Join(newLeft, newRight, LeftOuter, newJoinCond)
case FullOuter => f
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 2ad73941ab..58d415d901 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.expressions.{Count, Explode}
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -43,6 +43,8 @@ class FilterPushdownSuite extends PlanTest {
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+ val testRelation1 = LocalRelation('d.int)
+
// This test already passes.
test("eliminate subqueries") {
val originalQuery =
@@ -213,6 +215,23 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("joins: push down left semi join") {
+ val x = testRelation.subquery('x)
+ val y = testRelation1.subquery('y)
+
+ val originalQuery = {
+ x.join(y, LeftSemi, Option("x.a".attr === "y.d".attr && "x.b".attr >= 1 && "y.d".attr >= 2))
+ }
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val left = testRelation.where('b >= 1)
+ val right = testRelation1.where('d >= 2)
+ val correctAnswer =
+ left.join(right, LeftSemi, Option("a".attr === "d".attr)).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
test("joins: push down left outer join #1") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)