From e44274870dee308f4e3e8ce79457d8d19693b6e5 Mon Sep 17 00:00:00 2001
From: wangzhenhua
Date: Wed, 8 Mar 2017 16:01:28 +0100
Subject: [SPARK-17080][SQL] join reorder

## What changes were proposed in this pull request?

Reorder the joins using a dynamic programming algorithm (Selinger paper):
First we put all items (basic joined nodes) into level 1, then we build all two-way joins
at level 2 from plans at level 1 (single items), then build all 3-way joins from plans at
previous levels (two-way joins and single items), then 4-way joins ... etc, until we build
all n-way joins and pick the best plan among them.

When building m-way joins, we only keep the best plan (with the lowest cost) for the same
set of m items. E.g., for 3-way joins, we keep only the best plan for items {A, B, C} among
plans (A J B) J C, (A J C) J B and (B J C) J A.

Thus, the plans maintained for each level when reordering four items A, B, C, D are as follows:
```
level 1: p({A}), p({B}), p({C}), p({D})
level 2: p({A, B}), p({A, C}), p({A, D}), p({B, C}), p({B, D}), p({C, D})
level 3: p({A, B, C}), p({A, B, D}), p({A, C, D}), p({B, C, D})
level 4: p({A, B, C, D})
```
where p({A, B, C, D}) is the final output plan.

For cost evaluation, since physical costs for operators are not available currently, we use
cardinalities and sizes to compute costs.

## How was this patch tested?

add test cases

Author: wangzhenhua
Author: Zhenhua Wang

Closes #17138 from wzhfy/joinReorder.
---
 .../apache/spark/sql/catalyst/CatalystConf.scala   |   8 +
 .../catalyst/optimizer/CostBasedJoinReorder.scala  | 297 +++++++++++++++++++++
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   2 +
 .../sql/catalyst/optimizer/JoinReorderSuite.scala  | 194 ++++++++++++++
 .../apache/spark/sql/catalyst/plans/PlanTest.scala |   2 +-
 .../statsEstimation/StatsEstimationTestBase.scala  |   4 +-
 6 files changed, 504 insertions(+), 3 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index 5f50ce1ba6..fb99cb27b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -60,6 +60,12 @@ trait CatalystConf {
    * Enables CBO for estimation of plan statistics when set true.
    */
   def cboEnabled: Boolean
+
+  /** Enables join reorder in CBO. */
+  def joinReorderEnabled: Boolean
+
+  /** The maximum number of joined nodes allowed in the dynamic programming algorithm.
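+   * If the number of consecutive inner-join items exceeds this threshold, the original join
+   * order is kept, since the dynamic programming search space grows exponentially with the
+   * number of items.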
*/
+  def joinReorderDPThreshold: Int
 }
 
@@ -75,6 +81,8 @@ case class SimpleCatalystConf(
     runSQLonFile: Boolean = true,
     crossJoinEnabled: Boolean = false,
     cboEnabled: Boolean = false,
+    joinReorderEnabled: Boolean = false,
+    joinReorderDPThreshold: Int = 12,
     warehousePath: String = "/user/hive/warehouse",
     sessionLocalTimeZone: String = TimeZone.getDefault().getID)
   extends CatalystConf
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
new file mode 100644
index 0000000000..b694561e53
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -0,0 +1,297 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper}
+import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike}
+import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+
+/**
+ * Cost-based join reorder.
+ * We may have several join reorder algorithms in the future. This class is the entry point for
+ * these algorithms, and it chooses which one to use.
+ */
+case class CostBasedJoinReorder(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!conf.cboEnabled || !conf.joinReorderEnabled) {
+      plan
+    } else {
+      val result = plan transform {
+        case p @ Project(projectList, j @ Join(_, _, _: InnerLike, _)) =>
+          reorder(p, p.outputSet)
+        case j @ Join(_, _, _: InnerLike, _) =>
+          reorder(j, j.outputSet)
+      }
+      // After reordering is finished, convert OrderedJoin back to Join
+      result transform {
+        case oj: OrderedJoin => oj.join
+      }
+    }
+  }
+
+  def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = {
+    val (items, conditions) = extractInnerJoins(plan)
+    val result =
+      // Do reordering if the number of items is appropriate and join conditions exist.
+      // We also need to check if costs of all items can be evaluated.
+      if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty &&
+          items.forall(_.stats(conf).rowCount.isDefined)) {
+        JoinReorderDP.search(conf, items, conditions, output).getOrElse(plan)
+      } else {
+        plan
+      }
+    // Mark consecutive join nodes as ordered.
+    replaceWithOrderedJoin(result)
+  }
+
+  /**
+   * Extracts items of consecutive inner joins and their join conditions.
+   * This method works for both bushy trees and left/right-deep trees.
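+   * For example, Join(Join(A, B, Some(c1)), Join(C, D, Some(c2)), Some(c3)) is flattened to
+   * the items Seq(A, B, C, D) and the conditions {c1, c2, c3}, with each condition split into
+   * its conjunctive predicates.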
+ */ + private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { + plan match { + case Join(left, right, _: InnerLike, cond) => + val (leftPlans, leftConditions) = extractInnerJoins(left) + val (rightPlans, rightConditions) = extractInnerJoins(right) + (leftPlans ++ rightPlans, cond.toSet.flatMap(splitConjunctivePredicates) ++ + leftConditions ++ rightConditions) + case Project(projectList, join) if projectList.forall(_.isInstanceOf[Attribute]) => + extractInnerJoins(join) + case _ => + (Seq(plan), Set()) + } + } + + private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { + case j @ Join(left, right, _: InnerLike, cond) => + val replacedLeft = replaceWithOrderedJoin(left) + val replacedRight = replaceWithOrderedJoin(right) + OrderedJoin(j.copy(left = replacedLeft, right = replacedRight)) + case p @ Project(_, join) => + p.copy(child = replaceWithOrderedJoin(join)) + case _ => + plan + } + + /** This is a wrapper class for a join node that has been ordered. */ + private case class OrderedJoin(join: Join) extends BinaryNode { + override def left: LogicalPlan = join.left + override def right: LogicalPlan = join.right + override def output: Seq[Attribute] = join.output + } +} + +/** + * Reorder the joins using a dynamic programming algorithm. This implementation is based on the + * paper: Access Path Selection in a Relational Database Management System. + * http://www.inf.ed.ac.uk/teaching/courses/adbs/AccessPath.pdf + * + * First we put all items (basic joined nodes) into level 0, then we build all two-way joins + * at level 1 from plans at level 0 (single items), then build all 3-way joins from plans + * at previous levels (two-way joins and single items), then 4-way joins ... etc, until we + * build all n-way joins and pick the best plan among them. + * + * When building m-way joins, we only keep the best plan (with the lowest cost) for the same set + * of m items. E.g., for 3-way joins, we keep only the best plan for items {A, B, C} among + * plans (A J B) J C, (A J C) J B and (B J C) J A. + * + * Thus the plans maintained for each level when reordering four items A, B, C, D are as follows: + * level 0: p({A}), p({B}), p({C}), p({D}) + * level 1: p({A, B}), p({A, C}), p({A, D}), p({B, C}), p({B, D}), p({C, D}) + * level 2: p({A, B, C}), p({A, B, D}), p({A, C, D}), p({B, C, D}) + * level 3: p({A, B, C, D}) + * where p({A, B, C, D}) is the final output plan. + * + * For cost evaluation, since physical costs for operators are not available currently, we use + * cardinalities and sizes to compute costs. + */ +object JoinReorderDP extends PredicateHelper { + + def search( + conf: CatalystConf, + items: Seq[LogicalPlan], + conditions: Set[Expression], + topOutput: AttributeSet): Option[LogicalPlan] = { + + // Level i maintains all found plans for i + 1 items. + // Create the initial plans: each plan is a single item with zero cost. + val itemIndex = items.zipWithIndex + val foundPlans = mutable.Buffer[JoinPlanMap](itemIndex.map { + case (item, id) => Set(id) -> JoinPlan(Set(id), item, Set(), Cost(0, 0)) + }.toMap) + + for (lev <- 1 until items.length) { + // Build plans for the next level. + foundPlans += searchLevel(foundPlans, conf, conditions, topOutput) + } + + val plansLastLevel = foundPlans(items.length - 1) + if (plansLastLevel.isEmpty) { + // Failed to find a plan, fall back to the original plan + None + } else { + // There must be only one plan at the last level, which contains all items. 
+      assert(plansLastLevel.size == 1 && plansLastLevel.head._1.size == items.length)
+      Some(plansLastLevel.head._2.plan)
+    }
+  }
+
+  /** Find all possible plans at the next level, based on existing levels. */
+  private def searchLevel(
+      existingLevels: Seq[JoinPlanMap],
+      conf: CatalystConf,
+      conditions: Set[Expression],
+      topOutput: AttributeSet): JoinPlanMap = {
+
+    val nextLevel = mutable.Map.empty[Set[Int], JoinPlan]
+    var k = 0
+    val lev = existingLevels.length - 1
+    // Build plans for the next level from plans at level k (one side of the join) and level
+    // lev - k (the other side of the join).
+    // For the lower level k, we only need to search from 0 to lev - k, because when building
+    // a join from A and B, both A J B and B J A are handled.
+    while (k <= lev - k) {
+      val oneSideCandidates = existingLevels(k).values.toSeq
+      for (i <- oneSideCandidates.indices) {
+        val oneSidePlan = oneSideCandidates(i)
+        val otherSideCandidates = if (k == lev - k) {
+          // Both sides of a join are at the same level, no need to repeat for previous ones.
+          oneSideCandidates.drop(i)
+        } else {
+          existingLevels(lev - k).values.toSeq
+        }
+
+        otherSideCandidates.foreach { otherSidePlan =>
+          // Should not join two overlapping item sets.
+          if (oneSidePlan.itemIds.intersect(otherSidePlan.itemIds).isEmpty) {
+            val joinPlan = buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput)
+            // Check if it's the first plan for the item set, or it's a better plan than
+            // the existing one due to lower cost.
+            val existingPlan = nextLevel.get(joinPlan.itemIds)
+            if (existingPlan.isEmpty || joinPlan.cost.lessThan(existingPlan.get.cost)) {
+              nextLevel.update(joinPlan.itemIds, joinPlan)
+            }
+          }
+        }
+      }
+      k += 1
+    }
+    nextLevel.toMap
+  }
+
+  /** Build a new join node. */
+  private def buildJoin(
+      oneJoinPlan: JoinPlan,
+      otherJoinPlan: JoinPlan,
+      conf: CatalystConf,
+      conditions: Set[Expression],
+      topOutput: AttributeSet): JoinPlan = {
+
+    val onePlan = oneJoinPlan.plan
+    val otherPlan = otherJoinPlan.plan
+    // Now both onePlan and otherPlan become intermediate joins, so the cost of the
+    // new join should also include their own cardinalities and sizes.
+    val newCost = if (isCartesianProduct(onePlan) || isCartesianProduct(otherPlan)) {
+      // We consider a cartesian product very expensive, thus set a very large cost for it.
+      // This enables us to plan all the cartesian products at the end, because having a
+      // cartesian product as an intermediate join will significantly increase a plan's cost,
+      // making it impossible to be selected as the best plan for the items, unless there's
+      // no other choice.
+      Cost(
+        rows = BigInt(Long.MaxValue) * BigInt(Long.MaxValue),
+        size = BigInt(Long.MaxValue) * BigInt(Long.MaxValue))
+    } else {
+      val onePlanStats = onePlan.stats(conf)
+      val otherPlanStats = otherPlan.stats(conf)
+      Cost(
+        rows = oneJoinPlan.cost.rows + onePlanStats.rowCount.get +
+          otherJoinPlan.cost.rows + otherPlanStats.rowCount.get,
+        size = oneJoinPlan.cost.size + onePlanStats.sizeInBytes +
+          otherJoinPlan.cost.size + otherPlanStats.sizeInBytes)
+    }
+
+    // Put the deeper side on the left; this tends to build a left-deep tree.
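+    // E.g., when joining the plan for {A, B, C} with the single item {D}, the three-item
+    // side becomes the left child, yielding ((A J B) J C) J D rather than D J ((A J B) J C).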
+    val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) {
+      (onePlan, otherPlan)
+    } else {
+      (otherPlan, onePlan)
+    }
+    val joinConds = conditions
+      .filterNot(l => canEvaluate(l, onePlan))
+      .filterNot(r => canEvaluate(r, otherPlan))
+      .filter(e => e.references.subsetOf(onePlan.outputSet ++ otherPlan.outputSet))
+    // We use inner join whether the join condition is empty or not, since a cross join is
+    // equivalent to an inner join without a condition.
+    val newJoin = Join(left, right, Inner, joinConds.reduceOption(And))
+    val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds
+    val remainingConds = conditions -- collectedJoinConds
+    val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput
+    val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains)
+    val newPlan =
+      if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) {
+        Project(neededFromNewJoin.toSeq, newJoin)
+      } else {
+        newJoin
+      }
+
+    val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds)
+    JoinPlan(itemIds, newPlan, collectedJoinConds, newCost)
+  }
+
+  private def isCartesianProduct(plan: LogicalPlan): Boolean = plan match {
+    case Join(_, _, _, None) => true
+    case Project(_, Join(_, _, _, None)) => true
+    case _ => false
+  }
+
+  /** Map[set of item ids, join plan for these items] */
+  type JoinPlanMap = Map[Set[Int], JoinPlan]
+
+  /**
+   * Partial join order at a specific level.
+   *
+   * @param itemIds Set of item ids participating in this partial plan.
+   * @param plan The plan tree with the lowest cost for these items found so far.
+   * @param joinConds Join conditions included in the plan.
+   * @param cost The cost of this plan: the sum of the costs of all intermediate joins.
+   */
+  case class JoinPlan(itemIds: Set[Int], plan: LogicalPlan, joinConds: Set[Expression], cost: Cost)
+}
+
+/** This class defines the cost model. */
+case class Cost(rows: BigInt, size: BigInt) {
+  /**
+   * An empirical value for the weight of cardinality (number of rows) in the cost formula:
+   * cost = rows * weight + size * (1 - weight). Usually cardinality is more important than size.
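+   * For example, with weight = 0.7, a plan with Cost(rows = 100, size = 2000) is considered
+   * cheaper than one with Cost(rows = 200, size = 1000), because
+   * 0.7 * (100 / 200) + 0.3 * (2000 / 1000) = 0.95 < 1, so lessThan below returns true.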
+ */ + val weight = 0.7 + + def lessThan(other: Cost): Boolean = { + if (other.rows == 0 || other.size == 0) { + false + } else { + val relativeRows = BigDecimal(rows) / BigDecimal(other.rows) + val relativeSize = BigDecimal(size) / BigDecimal(other.size) + relativeRows * weight + relativeSize * (1 - weight) < 1 + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 036da3ad20..d5bbc6e8ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -118,6 +118,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) SimplifyCreateMapOps) :: Batch("Check Cartesian Products", Once, CheckCartesianProducts(conf)) :: + Batch("Join Reorder", Once, + CostBasedJoinReorder(conf)) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates(conf)) :: Batch("Typed Filter Optimization", fixedPoint, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala new file mode 100644 index 0000000000..1b2f7a66b6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.catalyst.util._ + + +class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { + + override val conf = SimpleCatalystConf( + caseSensitiveAnalysis = true, cboEnabled = true, joinReorderEnabled = true) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushDownPredicate, + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: + Batch("Join Reorder", Once, + CostBasedJoinReorder(conf)) :: Nil + } + + /** Set up tables and columns for testing */ + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + attr("t1.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t1.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t2.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t3.v-1-100") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + // Table t1/t4: big table with two columns + private val t1 = StatsTestPlan( + outputList = Seq("t1.k-1-2", "t1.v-1-10").map(nameToAttr), + rowCount = 1000, + // size = rows * (overhead + column length) + size = Some(1000 * (8 + 4 + 4)), + attributeStats = AttributeMap(Seq("t1.k-1-2", "t1.v-1-10").map(nameToColInfo))) + + private val t4 = StatsTestPlan( + outputList = Seq("t4.k-1-2", "t4.v-1-10").map(nameToAttr), + rowCount = 2000, + size = Some(2000 * (8 + 4 + 4)), + attributeStats = AttributeMap(Seq("t4.k-1-2", "t4.v-1-10").map(nameToColInfo))) + + // Table t2/t3: small table with only one column + private val t2 = StatsTestPlan( + outputList = Seq("t2.k-1-5").map(nameToAttr), + rowCount = 20, + size = Some(20 * (8 + 4)), + attributeStats = AttributeMap(Seq("t2.k-1-5").map(nameToColInfo))) + + private val t3 = StatsTestPlan( + outputList = Seq("t3.v-1-100").map(nameToAttr), + rowCount = 100, + size = Some(100 * (8 + 4)), + attributeStats = AttributeMap(Seq("t3.v-1-100").map(nameToColInfo))) + + test("reorder 3 tables") { + val originalPlan = + t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + + // The cost of original plan (use only cardinality to simplify explanation): + // cost = cost(t1 J t2) = 1000 * 20 / 5 = 
4000 + // In contrast, the cost of the best plan: + // cost = cost(t1 J t3) = 1000 * 100 / 100 = 1000 < 4000 + // so (t1 J t3) J t2 is better (has lower cost, i.e. intermediate result size) than + // the original order (t1 J t2) J t3. + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("reorder 3 tables - put cross join at the end") { + val originalPlan = + t1.join(t2).join(t3).where(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")) + + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, None) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("reorder 3 tables with pure-attribute project") { + val originalPlan = + t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.v-1-10")) + + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10")) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(nameToAttr("t1.v-1-10")) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("don't reorder if project contains non-attribute") { + val originalPlan = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select((nameToAttr("t1.k-1-2") + nameToAttr("t2.k-1-5")) as "key", nameToAttr("t1.v-1-10")) + .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select("key".attr) + + assertEqualPlans(originalPlan, originalPlan) + } + + test("reorder 4 tables (bushy tree)") { + val originalPlan = + t1.join(t4).join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) && + (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + + // The cost of original plan (use only cardinality to simplify explanation): + // cost(t1 J t4) = 1000 * 2000 / 2 = 1000000, cost(t1t4 J t2) = 1000000 * 20 / 5 = 4000000, + // cost = cost(t1 J t4) + cost(t1t4 J t2) = 5000000 + // In contrast, the cost of the best plan (a bushy tree): + // cost(t1 J t2) = 1000 * 20 / 5 = 4000, cost(t4 J t3) = 2000 * 100 / 100 = 2000, + // cost = cost(t1 J t2) + cost(t4 J t3) = 6000 << 5000000. + val bestPlan = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))), + Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) + + assertEqualPlans(originalPlan, bestPlan) + } + + private def assertEqualPlans( + originalPlan: LogicalPlan, + groundTruthBestPlan: LogicalPlan): Unit = { + val optimized = Optimize.execute(originalPlan.analyze) + val normalized1 = normalizePlan(normalizeExprIds(optimized)) + val normalized2 = normalizePlan(normalizeExprIds(groundTruthBestPlan.analyze)) + if (!sameJoinPlan(normalized1, normalized2)) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** Consider symmetry for joins when comparing plans. 
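+   * E.g., A J B is considered the same as B J A: for two joins, the children are compared
+   * both in order and swapped, recursively.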
*/
+  private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
+    (plan1, plan2) match {
+      case (j1: Join, j2: Join) =>
+        (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
+          (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
+      case _ =>
+        plan1 == plan2
+    }
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 3b7e5e938a..e9b7a0c6ad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -62,7 +62,7 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
    * - Sample: the seed will be replaced by 0L.
    * - Join conditions will be resorted by hashCode.
    */
-  private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
+  protected def normalizePlan(plan: LogicalPlan): LogicalPlan = {
     plan transform {
       case filter @ Filter(condition: Expression, child: LogicalPlan) =>
         Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index c56b41ce37..9b2b8dbe1b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, Logica
 import org.apache.spark.sql.types.{IntegerType, StringType}
 
 
-class StatsEstimationTestBase extends SparkFunSuite {
+trait StatsEstimationTestBase extends SparkFunSuite {
 
   /** Enable stats estimation based on CBO. */
   protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true)
@@ -48,7 +48,7 @@ class StatsEstimationTestBase extends SparkFunSuite {
 
 /**
  * This class is used for unit-testing. It's a logical plan whose output and stats are passed in.
  */
-protected case class StatsTestPlan(
+case class StatsTestPlan(
     outputList: Seq[Attribute],
     rowCount: BigInt,
     attributeStats: AttributeMap[ColumnStat],