author     wangzhenhua <wangzhenhua@huawei.com>            2017-03-08 16:01:28 +0100
committer  Herman van Hovell <hvanhovell@databricks.com>   2017-03-08 16:01:28 +0100
commit     e44274870dee308f4e3e8ce79457d8d19693b6e5 (patch)
tree       99cde8d5b623e14e9b1a8fa86ab152a4cba0e640
parent     9ea201cf6482c9c62c9428759d238063db62d66e (diff)
[SPARK-17080][SQL] join reorder
## What changes were proposed in this pull request?

Reorder the joins using a dynamic programming algorithm (Selinger paper):

First we put all items (basic joined nodes) into level 1, then we build all two-way joins at level 2 from plans at level 1 (single items), then build all 3-way joins from plans at previous levels (two-way joins and single items), then 4-way joins ... etc, until we build all n-way joins and pick the best plan among them.

When building m-way joins, we only keep the best plan (with the lowest cost) for the same set of m items. E.g., for 3-way joins, we keep only the best plan for items {A, B, C} among plans (A J B) J C, (A J C) J B and (B J C) J A.

Thus, the plans maintained for each level when reordering four items A, B, C, D are as follows:
```
level 1: p({A}), p({B}), p({C}), p({D})
level 2: p({A, B}), p({A, C}), p({A, D}), p({B, C}), p({B, D}), p({C, D})
level 3: p({A, B, C}), p({A, B, D}), p({A, C, D}), p({B, C, D})
level 4: p({A, B, C, D})
```
where p({A, B, C, D}) is the final output plan.

For cost evaluation, since physical costs for operators are not available currently, we use cardinalities and sizes to compute costs.

## How was this patch tested?

add test cases

Author: wangzhenhua <wangzhenhua@huawei.com>
Author: Zhenhua Wang <wzh_zju@163.com>

Closes #17138 from wzhfy/joinReorder.
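To make the search concrete, here is a minimal, self-contained Scala sketch of the level-by-level enumeration and per-item-set pruning described above. The names (`JoinReorderDPSketch`, `PartialPlan`) and the pluggable `joinCost` function are illustrative assumptions, not this patch's API; the actual rule is `JoinReorderDP` in `CostBasedJoinReorder.scala` below.

```scala
import scala.collection.mutable

// Sketch only: a simplified, condition-free version of the DP search described above.
object JoinReorderDPSketch {
  final case class PartialPlan(items: Set[Int], tree: String, cost: BigInt)

  /** Enumerates plans level by level; `joinCost` is an assumed, pluggable cost of a single join. */
  def search(numItems: Int, joinCost: (PartialPlan, PartialPlan) => BigInt): PartialPlan = {
    // Level 1: one zero-cost plan per base item.
    val levels = mutable.ArrayBuffer(
      (0 until numItems).map(i => Set(i) -> PartialPlan(Set(i), s"T$i", BigInt(0))).toMap)

    for (lev <- 1 until numItems) {
      val next = mutable.Map.empty[Set[Int], PartialPlan]
      // Build (lev + 1)-item plans by combining a k-item plan with a (lev + 1 - k)-item plan.
      // k only needs to cover the smaller half, since A J B and B J A cover the same item set.
      for {
        k <- 1 to (lev + 1) / 2
        left <- levels(k - 1).values
        right <- levels(lev - k).values
        if left.items.intersect(right.items).isEmpty  // never join overlapping item sets
      } {
        val items = left.items ++ right.items
        val cost = left.cost + right.cost + joinCost(left, right)
        // Pruning step: keep only the cheapest plan seen so far for this item set.
        if (next.get(items).forall(_.cost > cost)) {
          next(items) = PartialPlan(items, s"(${left.tree} J ${right.tree})", cost)
        }
      }
      levels += next.toMap
    }
    // The last level contains a single plan covering all items.
    levels.last.values.head
  }
}

// Example with an arbitrary toy cost:
// JoinReorderDPSketch.search(4, (l, r) => BigInt(l.items.size * r.items.size))
```

For four items this maintains exactly the level 1-4 plan sets listed above, with a single cheapest plan kept per item set at every level.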
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala                                8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala           297
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala                        2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala               194
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala                             2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala    4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala                                       16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala                           2
8 files changed, 521 insertions(+), 4 deletions(-)
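For reference, the rule added below only fires when cost-based optimization and join reordering are both enabled. A minimal usage sketch, assuming a spark-shell `spark` session (not part of the patch):

```scala
// Sketch only: enable CBO and the new join-reorder rule for the current session.
// The DP threshold caps how many joined items the exhaustive search will consider.
spark.sql("SET spark.sql.cbo.enabled=true")
spark.sql("SET spark.sql.cbo.joinReorder.enabled=true")
spark.sql("SET spark.sql.cbo.joinReorder.dp.threshold=12")  // default is 12
```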
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index 5f50ce1ba6..fb99cb27b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -60,6 +60,12 @@ trait CatalystConf {
* Enables CBO for estimation of plan statistics when set true.
*/
def cboEnabled: Boolean
+
+ /** Enables join reorder in CBO. */
+ def joinReorderEnabled: Boolean
+
+ /** The maximum number of joined nodes allowed in the dynamic programming algorithm. */
+ def joinReorderDPThreshold: Int
}
@@ -75,6 +81,8 @@ case class SimpleCatalystConf(
runSQLonFile: Boolean = true,
crossJoinEnabled: Boolean = false,
cboEnabled: Boolean = false,
+ joinReorderEnabled: Boolean = false,
+ joinReorderDPThreshold: Int = 12,
warehousePath: String = "/user/hive/warehouse",
sessionLocalTimeZone: String = TimeZone.getDefault().getID)
extends CatalystConf
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
new file mode 100644
index 0000000000..b694561e53
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -0,0 +1,297 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper}
+import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike}
+import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+
+/**
+ * Cost-based join reorder.
+ * We may have several join reorder algorithms in the future. This class is the entry point for
+ * these algorithms, and chooses which one to use.
+ */
+case class CostBasedJoinReorder(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper {
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ if (!conf.cboEnabled || !conf.joinReorderEnabled) {
+ plan
+ } else {
+ val result = plan transform {
+ case p @ Project(projectList, j @ Join(_, _, _: InnerLike, _)) =>
+ reorder(p, p.outputSet)
+ case j @ Join(_, _, _: InnerLike, _) =>
+ reorder(j, j.outputSet)
+ }
+ // After reordering is finished, convert OrderedJoin back to Join
+ result transform {
+ case oj: OrderedJoin => oj.join
+ }
+ }
+ }
+
+ def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = {
+ val (items, conditions) = extractInnerJoins(plan)
+ val result =
+ // Do reordering if the number of items is appropriate and join conditions exist.
+ // We also need to check if costs of all items can be evaluated.
+ if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty &&
+ items.forall(_.stats(conf).rowCount.isDefined)) {
+ JoinReorderDP.search(conf, items, conditions, output).getOrElse(plan)
+ } else {
+ plan
+ }
+ // Mark consecutive join nodes as ordered.
+ replaceWithOrderedJoin(result)
+ }
+
+ /**
+ * Extract consecutive inner joinable items and join conditions.
+ * This method works for bushy trees and left/right deep trees.
+ */
+ private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = {
+ plan match {
+ case Join(left, right, _: InnerLike, cond) =>
+ val (leftPlans, leftConditions) = extractInnerJoins(left)
+ val (rightPlans, rightConditions) = extractInnerJoins(right)
+ (leftPlans ++ rightPlans, cond.toSet.flatMap(splitConjunctivePredicates) ++
+ leftConditions ++ rightConditions)
+ case Project(projectList, join) if projectList.forall(_.isInstanceOf[Attribute]) =>
+ extractInnerJoins(join)
+ case _ =>
+ (Seq(plan), Set())
+ }
+ }
+
+ private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match {
+ case j @ Join(left, right, _: InnerLike, cond) =>
+ val replacedLeft = replaceWithOrderedJoin(left)
+ val replacedRight = replaceWithOrderedJoin(right)
+ OrderedJoin(j.copy(left = replacedLeft, right = replacedRight))
+ case p @ Project(_, join) =>
+ p.copy(child = replaceWithOrderedJoin(join))
+ case _ =>
+ plan
+ }
+
+ /** This is a wrapper class for a join node that has been ordered. */
+ private case class OrderedJoin(join: Join) extends BinaryNode {
+ override def left: LogicalPlan = join.left
+ override def right: LogicalPlan = join.right
+ override def output: Seq[Attribute] = join.output
+ }
+}
+
+/**
+ * Reorder the joins using a dynamic programming algorithm. This implementation is based on the
+ * paper: Access Path Selection in a Relational Database Management System.
+ * http://www.inf.ed.ac.uk/teaching/courses/adbs/AccessPath.pdf
+ *
+ * First we put all items (basic joined nodes) into level 0, then we build all two-way joins
+ * at level 1 from plans at level 0 (single items), then build all 3-way joins from plans
+ * at previous levels (two-way joins and single items), then 4-way joins ... etc, until we
+ * build all n-way joins and pick the best plan among them.
+ *
+ * When building m-way joins, we only keep the best plan (with the lowest cost) for the same set
+ * of m items. E.g., for 3-way joins, we keep only the best plan for items {A, B, C} among
+ * plans (A J B) J C, (A J C) J B and (B J C) J A.
+ *
+ * Thus the plans maintained for each level when reordering four items A, B, C, D are as follows:
+ * level 0: p({A}), p({B}), p({C}), p({D})
+ * level 1: p({A, B}), p({A, C}), p({A, D}), p({B, C}), p({B, D}), p({C, D})
+ * level 2: p({A, B, C}), p({A, B, D}), p({A, C, D}), p({B, C, D})
+ * level 3: p({A, B, C, D})
+ * where p({A, B, C, D}) is the final output plan.
+ *
+ * For cost evaluation, since physical costs for operators are not available currently, we use
+ * cardinalities and sizes to compute costs.
+ */
+object JoinReorderDP extends PredicateHelper {
+
+ def search(
+ conf: CatalystConf,
+ items: Seq[LogicalPlan],
+ conditions: Set[Expression],
+ topOutput: AttributeSet): Option[LogicalPlan] = {
+
+ // Level i maintains all found plans for i + 1 items.
+ // Create the initial plans: each plan is a single item with zero cost.
+ val itemIndex = items.zipWithIndex
+ val foundPlans = mutable.Buffer[JoinPlanMap](itemIndex.map {
+ case (item, id) => Set(id) -> JoinPlan(Set(id), item, Set(), Cost(0, 0))
+ }.toMap)
+
+ for (lev <- 1 until items.length) {
+ // Build plans for the next level.
+ foundPlans += searchLevel(foundPlans, conf, conditions, topOutput)
+ }
+
+ val plansLastLevel = foundPlans(items.length - 1)
+ if (plansLastLevel.isEmpty) {
+ // Failed to find a plan, fall back to the original plan
+ None
+ } else {
+ // There must be only one plan at the last level, which contains all items.
+ assert(plansLastLevel.size == 1 && plansLastLevel.head._1.size == items.length)
+ Some(plansLastLevel.head._2.plan)
+ }
+ }
+
+ /** Find all possible plans at the next level, based on existing levels. */
+ private def searchLevel(
+ existingLevels: Seq[JoinPlanMap],
+ conf: CatalystConf,
+ conditions: Set[Expression],
+ topOutput: AttributeSet): JoinPlanMap = {
+
+ val nextLevel = mutable.Map.empty[Set[Int], JoinPlan]
+ var k = 0
+ val lev = existingLevels.length - 1
+ // Build plans for the next level from plans at level k (one side of the join) and level
+ // lev - k (the other side of the join).
+ // For the lower level k, we only need to search from 0 to lev - k, because when building
+ // a join from A and B, both A J B and B J A are handled.
+ while (k <= lev - k) {
+ val oneSideCandidates = existingLevels(k).values.toSeq
+ for (i <- oneSideCandidates.indices) {
+ val oneSidePlan = oneSideCandidates(i)
+ val otherSideCandidates = if (k == lev - k) {
+ // Both sides of a join are at the same level, no need to repeat for previous ones.
+ oneSideCandidates.drop(i)
+ } else {
+ existingLevels(lev - k).values.toSeq
+ }
+
+ otherSideCandidates.foreach { otherSidePlan =>
+ // Should not join two overlapping item sets.
+ if (oneSidePlan.itemIds.intersect(otherSidePlan.itemIds).isEmpty) {
+ val joinPlan = buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput)
+ // Check if it's the first plan for the item set, or it's a better plan than
+ // the existing one due to lower cost.
+ val existingPlan = nextLevel.get(joinPlan.itemIds)
+ if (existingPlan.isEmpty || joinPlan.cost.lessThan(existingPlan.get.cost)) {
+ nextLevel.update(joinPlan.itemIds, joinPlan)
+ }
+ }
+ }
+ }
+ k += 1
+ }
+ nextLevel.toMap
+ }
+
+ /** Build a new join node. */
+ private def buildJoin(
+ oneJoinPlan: JoinPlan,
+ otherJoinPlan: JoinPlan,
+ conf: CatalystConf,
+ conditions: Set[Expression],
+ topOutput: AttributeSet): JoinPlan = {
+
+ val onePlan = oneJoinPlan.plan
+ val otherPlan = otherJoinPlan.plan
+ // Now both onePlan and otherPlan become intermediate joins, so the cost of the
+ // new join should also include their own cardinalities and sizes.
+ val newCost = if (isCartesianProduct(onePlan) || isCartesianProduct(otherPlan)) {
+ // We consider cartesian products very expensive, thus set a very large cost for them.
+ // This pushes all cartesian products to the end of the plan, because having a cartesian
+ // product as an intermediate join significantly increases a plan's cost, making it
+ // impossible to be selected as the best plan for the items unless there is no other choice.
+ Cost(
+ rows = BigInt(Long.MaxValue) * BigInt(Long.MaxValue),
+ size = BigInt(Long.MaxValue) * BigInt(Long.MaxValue))
+ } else {
+ val onePlanStats = onePlan.stats(conf)
+ val otherPlanStats = otherPlan.stats(conf)
+ Cost(
+ rows = oneJoinPlan.cost.rows + onePlanStats.rowCount.get +
+ otherJoinPlan.cost.rows + otherPlanStats.rowCount.get,
+ size = oneJoinPlan.cost.size + onePlanStats.sizeInBytes +
+ otherJoinPlan.cost.size + otherPlanStats.sizeInBytes)
+ }
+
+ // Put the deeper side on the left, tending to build a left-deep tree.
+ val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) {
+ (onePlan, otherPlan)
+ } else {
+ (otherPlan, onePlan)
+ }
+ val joinConds = conditions
+ .filterNot(l => canEvaluate(l, onePlan))
+ .filterNot(r => canEvaluate(r, otherPlan))
+ .filter(e => e.references.subsetOf(onePlan.outputSet ++ otherPlan.outputSet))
+ // We use an inner join whether or not the join condition is empty, since a cross join is
+ // equivalent to an inner join without a condition.
+ val newJoin = Join(left, right, Inner, joinConds.reduceOption(And))
+ val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds
+ val remainingConds = conditions -- collectedJoinConds
+ val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput
+ val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains)
+ val newPlan =
+ if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) {
+ Project(neededFromNewJoin.toSeq, newJoin)
+ } else {
+ newJoin
+ }
+
+ val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds)
+ JoinPlan(itemIds, newPlan, collectedJoinConds, newCost)
+ }
+
+ private def isCartesianProduct(plan: LogicalPlan): Boolean = plan match {
+ case Join(_, _, _, None) => true
+ case Project(_, Join(_, _, _, None)) => true
+ case _ => false
+ }
+
+ /** Map[set of item ids, join plan for these items] */
+ type JoinPlanMap = Map[Set[Int], JoinPlan]
+
+ /**
+ * Partial join order in a specific level.
+ *
+ * @param itemIds Set of item ids participating in this partial plan.
+ * @param plan The plan tree with the lowest cost for these items found so far.
+ * @param joinConds Join conditions included in the plan.
+ * @param cost The cost of this plan tree, i.e. the sum of the costs of all intermediate joins.
+ */
+ case class JoinPlan(itemIds: Set[Int], plan: LogicalPlan, joinConds: Set[Expression], cost: Cost)
+}
+
+/** This class defines the cost model. */
+case class Cost(rows: BigInt, size: BigInt) {
+ /**
+ * An empirical value for the weight of cardinality (number of rows) in the cost formula:
+ * cost = rows * weight + size * (1 - weight); usually cardinality is more important than size.
+ */
+ val weight = 0.7
+
+ def lessThan(other: Cost): Boolean = {
+ if (other.rows == 0 || other.size == 0) {
+ false
+ } else {
+ val relativeRows = BigDecimal(rows) / BigDecimal(other.rows)
+ val relativeSize = BigDecimal(size) / BigDecimal(other.size)
+ relativeRows * weight + relativeSize * (1 - weight) < 1
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 036da3ad20..d5bbc6e8ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -118,6 +118,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
SimplifyCreateMapOps) ::
Batch("Check Cartesian Products", Once,
CheckCartesianProducts(conf)) ::
+ Batch("Join Reorder", Once,
+ CostBasedJoinReorder(conf)) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates(conf)) ::
Batch("Typed Filter Optimization", fixedPoint,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
new file mode 100644
index 0000000000..1b2f7a66b6
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
+import org.apache.spark.sql.catalyst.util._
+
+
+class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
+
+ override val conf = SimpleCatalystConf(
+ caseSensitiveAnalysis = true, cboEnabled = true, joinReorderEnabled = true)
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Operator Optimizations", FixedPoint(100),
+ CombineFilters,
+ PushDownPredicate,
+ PushPredicateThroughJoin,
+ ColumnPruning,
+ CollapseProject) ::
+ Batch("Join Reorder", Once,
+ CostBasedJoinReorder(conf)) :: Nil
+ }
+
+ /** Set up tables and columns for testing */
+ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
+ attr("t1.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("t1.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("t2.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("t3.v-1-100") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)
+ ))
+
+ private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
+ private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
+ columnInfo.map(kv => kv._1.name -> kv)
+
+ // Table t1/t4: big table with two columns
+ private val t1 = StatsTestPlan(
+ outputList = Seq("t1.k-1-2", "t1.v-1-10").map(nameToAttr),
+ rowCount = 1000,
+ // size = rows * (overhead + column length)
+ size = Some(1000 * (8 + 4 + 4)),
+ attributeStats = AttributeMap(Seq("t1.k-1-2", "t1.v-1-10").map(nameToColInfo)))
+
+ private val t4 = StatsTestPlan(
+ outputList = Seq("t4.k-1-2", "t4.v-1-10").map(nameToAttr),
+ rowCount = 2000,
+ size = Some(2000 * (8 + 4 + 4)),
+ attributeStats = AttributeMap(Seq("t4.k-1-2", "t4.v-1-10").map(nameToColInfo)))
+
+ // Table t2/t3: small table with only one column
+ private val t2 = StatsTestPlan(
+ outputList = Seq("t2.k-1-5").map(nameToAttr),
+ rowCount = 20,
+ size = Some(20 * (8 + 4)),
+ attributeStats = AttributeMap(Seq("t2.k-1-5").map(nameToColInfo)))
+
+ private val t3 = StatsTestPlan(
+ outputList = Seq("t3.v-1-100").map(nameToAttr),
+ rowCount = 100,
+ size = Some(100 * (8 + 4)),
+ attributeStats = AttributeMap(Seq("t3.v-1-100").map(nameToColInfo)))
+
+ test("reorder 3 tables") {
+ val originalPlan =
+ t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+ (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+
+ // The cost of original plan (use only cardinality to simplify explanation):
+ // cost = cost(t1 J t2) = 1000 * 20 / 5 = 4000
+ // In contrast, the cost of the best plan:
+ // cost = cost(t1 J t3) = 1000 * 100 / 100 = 1000 < 4000
+ // so (t1 J t3) J t2 is better (has lower cost, i.e. intermediate result size) than
+ // the original order (t1 J t2) J t3.
+ val bestPlan =
+ t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+ .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+
+ assertEqualPlans(originalPlan, bestPlan)
+ }
+
+ test("reorder 3 tables - put cross join at the end") {
+ val originalPlan =
+ t1.join(t2).join(t3).where(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))
+
+ val bestPlan =
+ t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+ .join(t2, Inner, None)
+
+ assertEqualPlans(originalPlan, bestPlan)
+ }
+
+ test("reorder 3 tables with pure-attribute project") {
+ val originalPlan =
+ t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+ (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+ .select(nameToAttr("t1.v-1-10"))
+
+ val bestPlan =
+ t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+ .select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10"))
+ .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+ .select(nameToAttr("t1.v-1-10"))
+
+ assertEqualPlans(originalPlan, bestPlan)
+ }
+
+ test("don't reorder if project contains non-attribute") {
+ val originalPlan =
+ t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+ .select((nameToAttr("t1.k-1-2") + nameToAttr("t2.k-1-5")) as "key", nameToAttr("t1.v-1-10"))
+ .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+ .select("key".attr)
+
+ assertEqualPlans(originalPlan, originalPlan)
+ }
+
+ test("reorder 4 tables (bushy tree)") {
+ val originalPlan =
+ t1.join(t4).join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+ (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+ (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+ // The cost of original plan (use only cardinality to simplify explanation):
+ // cost(t1 J t4) = 1000 * 2000 / 2 = 1000000, cost(t1t4 J t2) = 1000000 * 20 / 5 = 4000000,
+ // cost = cost(t1 J t4) + cost(t1t4 J t2) = 5000000
+ // In contrast, the cost of the best plan (a bushy tree):
+ // cost(t1 J t2) = 1000 * 20 / 5 = 4000, cost(t4 J t3) = 2000 * 100 / 100 = 2000,
+ // cost = cost(t1 J t2) + cost(t4 J t3) = 6000 << 5000000.
+ val bestPlan =
+ t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+ .join(t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
+ Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+
+ assertEqualPlans(originalPlan, bestPlan)
+ }
+
+ private def assertEqualPlans(
+ originalPlan: LogicalPlan,
+ groundTruthBestPlan: LogicalPlan): Unit = {
+ val optimized = Optimize.execute(originalPlan.analyze)
+ val normalized1 = normalizePlan(normalizeExprIds(optimized))
+ val normalized2 = normalizePlan(normalizeExprIds(groundTruthBestPlan.analyze))
+ if (!sameJoinPlan(normalized1, normalized2)) {
+ fail(
+ s"""
+ |== FAIL: Plans do not match ===
+ |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
+ """.stripMargin)
+ }
+ }
+
+ /** Consider symmetry for joins when comparing plans. */
+ private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
+ (plan1, plan2) match {
+ case (j1: Join, j2: Join) =>
+ (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
+ (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
+ case _ =>
+ plan1 == plan2
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 3b7e5e938a..e9b7a0c6ad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -62,7 +62,7 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
* - Sample the seed will be replaced by 0L.
* - Join conditions will be resorted by hashCode.
*/
- private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
+ protected def normalizePlan(plan: LogicalPlan): LogicalPlan = {
plan transform {
case filter @ Filter(condition: Expression, child: LogicalPlan) =>
Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index c56b41ce37..9b2b8dbe1b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, Logica
import org.apache.spark.sql.types.{IntegerType, StringType}
-class StatsEstimationTestBase extends SparkFunSuite {
+trait StatsEstimationTestBase extends SparkFunSuite {
/** Enable stats estimation based on CBO. */
protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true)
@@ -48,7 +48,7 @@ class StatsEstimationTestBase extends SparkFunSuite {
/**
* This class is used for unit-testing. It's a logical plan whose output and stats are passed in.
*/
-protected case class StatsTestPlan(
+case class StatsTestPlan(
outputList: Seq[Attribute],
rowCount: BigInt,
attributeStats: AttributeMap[ColumnStat],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index fd3acd42e8..94e3fa7dd1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -668,6 +668,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val JOIN_REORDER_ENABLED =
+ buildConf("spark.sql.cbo.joinReorder.enabled")
+ .doc("Enables join reorder in CBO.")
+ .booleanConf
+ .createWithDefault(false)
+
+ val JOIN_REORDER_DP_THRESHOLD =
+ buildConf("spark.sql.cbo.joinReorder.dp.threshold")
+ .doc("The maximum number of joined nodes allowed in the dynamic programming algorithm.")
+ .intConf
+ .createWithDefault(12)
+
val SESSION_LOCAL_TIMEZONE =
buildConf("spark.sql.session.timeZone")
.doc("""The ID of session local timezone, e.g. "GMT", "America/Los_Angeles", etc.""")
@@ -885,6 +897,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
override def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED)
+ override def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED)
+
+ override def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD)
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index d44a6e41cb..a4d012cd76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -45,7 +45,7 @@ class SparkSqlParserSuite extends PlanTest {
* Normalizes plans:
* - CreateTable the createTime in tableDesc will be replaced by -1L.
*/
- private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
+ override def normalizePlan(plan: LogicalPlan): LogicalPlan = {
plan match {
case CreateTable(tableDesc, mode, query) =>
val newTableDesc = tableDesc.copy(createTime = -1L)