-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala    | 56
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala      | 40
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala | 95
3 files changed, 185 insertions(+), 6 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 06d14fcf8b..f6088695a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -18,14 +18,12 @@
package org.apache.spark.sql.catalyst.optimizer
import scala.collection.immutable.HashSet
+
import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.plans.FullOuter
-import org.apache.spark.sql.catalyst.plans.LeftOuter
-import org.apache.spark.sql.catalyst.plans.RightOuter
-import org.apache.spark.sql.catalyst.plans.LeftSemi
+import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
+import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._
@@ -44,6 +42,7 @@ object DefaultOptimizer extends Optimizer {
// Operator push down
SetOperationPushDown,
SamplePushDown,
+ ReorderJoin,
PushPredicateThroughJoin,
PushPredicateThroughProject,
PushPredicateThroughGenerate,
@@ -712,6 +711,53 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
}
/**
+ * Reorders the joins and pushes all the conditions into the joins, so that the
+ * bottom-most joins have at least one condition.
+ *
+ * The order of the joins is left unchanged if all of them already have at least one condition.
+ */
+object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
+
+ /**
+   * Joins a list of plans together and pushes down the conditions into them.
+   *
+   * The joined plans are picked from left to right; a plan that shares at least one join
+   * condition with the plans already joined is preferred over one that does not.
+   *
+   * @param input a list of LogicalPlans to be joined.
+   * @param conditions a list of join conditions.
+ */
+ def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = {
+ assert(input.size >= 2)
+ if (input.size == 2) {
+ Join(input(0), input(1), Inner, conditions.reduceLeftOption(And))
+ } else {
+ val left :: rest = input.toList
+      // find the first plan on the right that shares at least one join condition with `left`
+ val conditionalJoin = rest.find { plan =>
+ val refs = left.outputSet ++ plan.outputSet
+ conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan))
+ .exists(_.references.subsetOf(refs))
+ }
+      // if no plan shares a condition with `left`, fall back to the next one (a cartesian product)
+ val right = conditionalJoin.getOrElse(rest.head)
+
+ val joinedRefs = left.outputSet ++ right.outputSet
+ val (joinConditions, others) = conditions.partition(_.references.subsetOf(joinedRefs))
+ val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And))
+
+      // remove `right` by reference equality (eq), so that a structurally equal duplicate plan is kept
+ createOrderedJoin(Seq(joined) ++ rest.filterNot(_ eq right), others)
+ }
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case j @ ExtractFiltersAndInnerJoins(input, conditions)
+ if input.size > 2 && conditions.nonEmpty =>
+ createOrderedJoin(input, conditions)
+ }
+}
+
+/**
* Pushes down [[Filter]] operators where the `condition` can be
* evaluated using only the attributes of the left or right side of a join. Other
* [[Filter]] conditions are moved into the `condition` of the [[Join]].
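
The greedy strategy in createOrderedJoin above is small enough to model outside
Catalyst. The following is a minimal sketch of it, assuming plain strings as
stand-in plans and pairs of plan names as stand-in conditions; in the real rule
a condition is matched by its attribute references, not by name pairs, and the
names and types here are illustrative only, not part of the patch.

object ReorderJoinSketch {
  type Plan = String
  type Cond = (String, String) // the two plans a condition references

  def createOrderedJoin(input: Seq[Plan], conditions: Seq[Cond]): Plan = {
    require(input.size >= 2)
    if (input.size == 2) {
      // base case: join the last two plans, attaching whatever conditions remain
      s"Join(${input(0)}, ${input(1)}, on=${conditions.mkString(" AND ")})"
    } else {
      val (left, rest) = (input.head, input.tail)
      // prefer the first plan on the right that shares a condition with `left`
      val right = rest.find { p =>
        conditions.exists { case (a, b) => Set(a, b) == Set(left, p) }
      }.getOrElse(rest.head)
      val (used, others) =
        conditions.partition { case (a, b) => Set(a, b) == Set(left, right) }
      val joined = s"Join($left, $right, on=${used.mkString(" AND ")})"
      createOrderedJoin(joined +: rest.filterNot(_ == right), others)
    }
  }

  def main(args: Array[String]): Unit = {
    // x JOIN y JOIN z with conditions on (x,z) and (y,z): z is pulled next to
    // x, so no intermediate join is left without a condition.
    println(createOrderedJoin(Seq("x", "y", "z"), Seq(("x", "z"), ("y", "z"))))
    // prints: Join(Join(x, z, on=(x,z)), y, on=(y,z))
  }
}
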
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 6f4f11406d..cd3f15cbe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.trees.TreeNodeRef
/**
* A pattern that matches any number of project or filter operations on top of another relational
@@ -133,6 +132,45 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
}
/**
+ * A pattern that collects consecutive inner joins and the filters on top of them.
+ *
+ * Filter
+ * |
+ * inner Join
+ * / \ ----> (Seq(plan0, plan1, plan2), conditions)
+ * Filter plan2
+ * |
+ * inner join
+ * / \
+ * plan0 plan1
+ *
+ * Note: This pattern currently only works for left-deep trees.
+ */
+object ExtractFiltersAndInnerJoins extends PredicateHelper {
+
+  // flattens all inner joins that are directly adjacent to each other
+ def flattenJoin(plan: LogicalPlan): (Seq[LogicalPlan], Seq[Expression]) = plan match {
+ case Join(left, right, Inner, cond) =>
+ val (plans, conditions) = flattenJoin(left)
+ (plans ++ Seq(right), conditions ++ cond.toSeq)
+
+    case Filter(filterCondition, j @ Join(_, _, Inner, _)) =>
+ val (plans, conditions) = flattenJoin(j)
+ (plans, conditions ++ splitConjunctivePredicates(filterCondition))
+
+ case _ => (Seq(plan), Seq())
+ }
+
+ def unapply(plan: LogicalPlan): Option[(Seq[LogicalPlan], Seq[Expression])] = plan match {
+ case f @ Filter(filterCondition, j @ Join(_, _, Inner, _)) =>
+ Some(flattenJoin(f))
+ case j @ Join(_, _, Inner, _) =>
+ Some(flattenJoin(j))
+ case _ => None
+ }
+}
+
+/**
* A pattern that collects all adjacent unions and returns their children as a Seq.
*/
object Unions {
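
The flattening itself can also be modeled in isolation. Below is a sketch over
a toy plan ADT that only loosely mirrors the Catalyst operators; note that the
real code additionally splits conjunctive filter predicates with
splitConjunctivePredicates, which the toy omits.

object FlattenJoinSketch {
  sealed trait Plan
  case class Leaf(name: String) extends Plan
  case class InnerJoin(left: Plan, right: Plan, cond: Option[String]) extends Plan
  case class Filter(cond: String, child: Plan) extends Plan

  // left-deep only: recurse into the left child, never the right
  def flattenJoin(plan: Plan): (Seq[Plan], Seq[String]) = plan match {
    case InnerJoin(left, right, cond) =>
      val (plans, conds) = flattenJoin(left)
      (plans :+ right, conds ++ cond.toSeq)
    case Filter(cond, j: InnerJoin) =>
      // a filter on top of an inner join contributes its predicate
      val (plans, conds) = flattenJoin(j)
      (plans, conds :+ cond)
    case _ => (Seq(plan), Seq())
  }
}

Matching the diagram above, flattening
Filter("f", InnerJoin(InnerJoin(Leaf("plan0"), Leaf("plan1"), Some("c")), Leaf("plan2"), None))
yields (Seq(Leaf("plan0"), Leaf("plan1"), Leaf("plan2")), Seq("c", "f")).
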
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala
new file mode 100644
index 0000000000..9b1e16c727
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+
+class JoinOrderSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Subqueries", Once,
+ EliminateSubQueries) ::
+ Batch("Filter Pushdown", Once,
+ CombineFilters,
+ PushPredicateThroughProject,
+ BooleanSimplification,
+ ReorderJoin,
+ PushPredicateThroughJoin,
+ PushPredicateThroughGenerate,
+ PushPredicateThroughAggregate,
+ ColumnPruning,
+ ProjectCollapsing) :: Nil
+  }
+
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+ val testRelation1 = LocalRelation('d.int)
+
+ test("extract filters and joins") {
+ val x = testRelation.subquery('x)
+ val y = testRelation1.subquery('y)
+ val z = testRelation.subquery('z)
+
+    def testExtract(plan: LogicalPlan, expected: Option[(Seq[LogicalPlan], Seq[Expression])]): Unit = {
+ assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected)
+ }
+
+ testExtract(x, None)
+ testExtract(x.where("x.b".attr === 1), None)
+ testExtract(x.join(y), Some(Seq(x, y), Seq()))
+ testExtract(x.join(y, condition = Some("x.b".attr === "y.d".attr)),
+ Some(Seq(x, y), Seq("x.b".attr === "y.d".attr)))
+ testExtract(x.join(y).where("x.b".attr === "y.d".attr),
+ Some(Seq(x, y), Seq("x.b".attr === "y.d".attr)))
+ testExtract(x.join(y).join(z), Some(Seq(x, y, z), Seq()))
+ testExtract(x.join(y).where("x.b".attr === "y.d".attr).join(z),
+ Some(Seq(x, y, z), Seq("x.b".attr === "y.d".attr)))
+ testExtract(x.join(y).join(x.join(z)), Some(Seq(x, y, x.join(z)), Seq()))
+ testExtract(x.join(y).join(x.join(z)).where("x.b".attr === "y.d".attr),
+ Some(Seq(x, y, x.join(z)), Seq("x.b".attr === "y.d".attr)))
+ }
+
+ test("reorder inner joins") {
+ val x = testRelation.subquery('x)
+ val y = testRelation1.subquery('y)
+ val z = testRelation.subquery('z)
+
+ val originalQuery = {
+ x.join(y).join(z)
+ .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr))
+ }
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ x.join(z, condition = Some("x.b".attr === "z.b".attr))
+ .join(y, condition = Some("y.d".attr === "z.a".attr))
+ .analyze
+
+ comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer))
+ }
+}
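
For intuition, the "reorder inner joins" test exercises exactly the case the
rule targets: x and y share no join condition, so joining them in the original
order would produce a cartesian product, while the optimized plan joins z to x
first. A hypothetical variant with a single usable condition (not part of the
patch, reusing the same fixtures) would behave as follows:

  // only x.b === z.b is usable, so z is joined to x first and y is joined
  // last without a condition (the unavoidable cartesian product)
  val query = x.join(y).join(z).where("x.b".attr === "z.b".attr)
  val optimized = Optimize.execute(query.analyze)
  // expected shape: Join(Join(x, z, x.b === z.b), y, None)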