From 81639115947a13017d1637549a8f66ba599b27b8 Mon Sep 17 00:00:00 2001
From: Ioana Delaney
Date: Mon, 20 Mar 2017 16:04:58 +0800
Subject: [SPARK-17791][SQL] Join reordering using star schema detection

## What changes were proposed in this pull request?

A star schema consists of one or more fact tables referencing a number of dimension tables. In general, queries against a star schema are expected to run fast because of the established RI constraints among the tables. This design proposes a join reordering based on natural, generally accepted heuristics for star schema queries:
- Finds the star join with the largest fact table and places it on the driving arm of the left-deep join. This plan avoids large tables on the inner, and thus favors hash joins.
- Applies the most selective dimensions early in the plan to reduce the amount of data flow.

The design document was included in SPARK-17791. Link to the Google doc: [StarSchemaDetection](https://docs.google.com/document/d/1UAfwbm_A6wo7goHlVZfYK99pqDMEZUumi7pubJXETEA/edit?usp=sharing)

## How was this patch tested?

A new test suite, StarJoinReorderSuite.scala, was implemented.

Author: Ioana Delaney

Closes #15363 from ioana-delaney/starJoinReord2.
---
 .../spark/sql/catalyst/SimpleCatalystConf.scala    |   1 +
 .../catalyst/optimizer/CostBasedJoinReorder.scala  |   2 +
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   2 +-
 .../spark/sql/catalyst/optimizer/joins.scala       | 350 ++++++++++++-
 .../spark/sql/catalyst/planning/patterns.scala     |   4 +-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  16 +
 .../catalyst/optimizer/JoinOptimizationSuite.scala |   4 +-
 .../sql/catalyst/optimizer/JoinReorderSuite.scala  |  29 +-
 .../catalyst/optimizer/StarJoinReorderSuite.scala  | 580 +++++++++++++++++++++
 .../apache/spark/sql/catalyst/plans/PlanTest.scala |  26 +
 10 files changed, 978 insertions(+), 36 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala
(limited to 'sql/catalyst')

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala
index 0d4903e03b..ac97987c55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala
@@ -40,6 +40,7 @@ case class SimpleCatalystConf(
   override val cboEnabled: Boolean = false,
   override val joinReorderEnabled: Boolean = false,
   override val joinReorderDPThreshold: Int = 12,
+  override val starSchemaDetection: Boolean = false,
   override val warehousePath: String = "/user/hive/warehouse",
   override val sessionLocalTimeZone: String = TimeZone.getDefault().getID,
   override val maxNestedViewDepth: Int = 100)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index 1b32bda72b..521c468fe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -53,6 +53,8 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
   def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = {
     val (items, conditions) = extractInnerJoins(plan)
+    // TODO: Compute the set of star-joins and use them in the join enumeration
+    // algorithm to prune sub-optimal plan choices.
     val result =
       // Do reordering if the number of items is appropriate and join conditions exist.
       // We also need to check if costs of all items can be evaluated.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c8ed4190a1..d7524a57ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -82,7 +82,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
     Batch("Operator Optimizations", fixedPoint,
       // Operator push down
       PushProjectionThroughUnion,
-      ReorderJoin,
+      ReorderJoin(conf),
       EliminateOuterJoin,
       PushPredicateThroughJoin,
       PushDownPredicate,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index bfe529e21e..58e4a230f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -20,19 +20,347 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
+import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, PhysicalOperation}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Encapsulates star-schema join detection.
+ */
+case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
+
+  /**
+   * A star schema consists of one or more fact tables referencing a number of dimension
+   * tables. In general, star-schema joins are detected using the following conditions:
+   *  1. Informational RI constraints (reliable detection)
+   *     + Dimension contains a primary key that is being joined to the fact table.
+   *     + Fact table contains foreign keys referencing multiple dimension tables.
+   *  2. Cardinality based heuristics
+   *     + Usually, the table with the highest cardinality is the fact table.
+   *     + The table joined with the largest number of tables is the fact table.
+   *
+   * To detect star joins, the algorithm uses a combination of the above two conditions.
+   * The fact table is chosen based on the cardinality heuristics, and the dimension
+   * tables are chosen based on the RI constraints. A star join will consist of the largest
+   * fact table joined with the dimension tables on their primary keys. To detect that a
+   * column is a primary key, the algorithm uses table and column statistics.
+   *
+   * Since Catalyst only supports left-deep tree plans, the algorithm currently returns only
+   * the star join with the largest fact table. Choosing the largest fact table on the
+   * driving arm to avoid large inners is in general a good heuristic. This restriction can
+   * be lifted with support for bushy tree plans.
+   *
+   * The highlights of the algorithm are the following:
+   *
+   * Given a set of joined tables/plans, the algorithm first verifies if they are eligible
+   * for star join detection. An eligible plan is a base table access with valid statistics.
+ * A base table access represents Project or Filter operators above a LeafNode. Conservatively, + * the algorithm only considers base table access as part of a star join since they provide + * reliable statistics. + * + * If some of the plans are not base table access, or statistics are not available, the algorithm + * returns an empty star join plan since, in the absence of statistics, it cannot make + * good planning decisions. Otherwise, the algorithm finds the table with the largest cardinality + * (number of rows), which is assumed to be a fact table. + * + * Next, it computes the set of dimension tables for the current fact table. A dimension table + * is assumed to be in a RI relationship with a fact table. To infer column uniqueness, + * the algorithm compares the number of distinct values with the total number of rows in the + * table. If their relative difference is within certain limits (i.e. ndvMaxError * 2, adjusted + * based on 1TB TPC-DS data), the column is assumed to be unique. + */ + def findStarJoins( + input: Seq[LogicalPlan], + conditions: Seq[Expression]): Seq[Seq[LogicalPlan]] = { + + val emptyStarJoinPlan = Seq.empty[Seq[LogicalPlan]] + + if (!conf.starSchemaDetection || input.size < 2) { + emptyStarJoinPlan + } else { + // Find if the input plans are eligible for star join detection. + // An eligible plan is a base table access with valid statistics. + val foundEligibleJoin = input.forall { + case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true + case _ => false + } + + if (!foundEligibleJoin) { + // Some plans don't have stats or are complex plans. Conservatively, + // return an empty star join. This restriction can be lifted + // once statistics are propagated in the plan. + emptyStarJoinPlan + } else { + // Find the fact table using cardinality based heuristics i.e. + // the table with the largest number of rows. + val sortedFactTables = input.map { plan => + TableAccessCardinality(plan, getTableAccessCardinality(plan)) + }.collect { case t @ TableAccessCardinality(_, Some(_)) => + t + }.sortBy(_.size)(implicitly[Ordering[Option[BigInt]]].reverse) + + sortedFactTables match { + case Nil => + emptyStarJoinPlan + case table1 :: table2 :: _ + if table2.size.get.toDouble > conf.starSchemaFTRatio * table1.size.get.toDouble => + // If the top largest tables have comparable number of rows, return an empty star plan. + // This restriction will be lifted when the algorithm is generalized + // to return multiple star plans. + emptyStarJoinPlan + case TableAccessCardinality(factTable, _) :: rest => + // Find the fact table joins. + val allFactJoins = rest.collect { case TableAccessCardinality(plan, _) + if findJoinConditions(factTable, plan, conditions).nonEmpty => + plan + } + + // Find the corresponding join conditions. + val allFactJoinCond = allFactJoins.flatMap { plan => + val joinCond = findJoinConditions(factTable, plan, conditions) + joinCond + } + + // Verify if the join columns have valid statistics. + // Allow any relational comparison between the tables. Later + // we will heuristically choose a subset of equi-join + // tables. 
+ val areStatsAvailable = allFactJoins.forall { dimTable => + allFactJoinCond.exists { + case BinaryComparison(lhs: AttributeReference, rhs: AttributeReference) => + val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs + val factCol = if (factTable.outputSet.contains(lhs)) lhs else rhs + hasStatistics(dimCol, dimTable) && hasStatistics(factCol, factTable) + case _ => false + } + } + + if (!areStatsAvailable) { + emptyStarJoinPlan + } else { + // Find the subset of dimension tables. A dimension table is assumed to be in a + // RI relationship with the fact table. Only consider equi-joins + // between a fact and a dimension table to avoid expanding joins. + val eligibleDimPlans = allFactJoins.filter { dimTable => + allFactJoinCond.exists { + case cond @ Equality(lhs: AttributeReference, rhs: AttributeReference) => + val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs + isUnique(dimCol, dimTable) + case _ => false + } + } + + if (eligibleDimPlans.isEmpty) { + // An eligible star join was not found because the join is not + // an RI join, or the star join is an expanding join. + emptyStarJoinPlan + } else { + Seq(factTable +: eligibleDimPlans) + } + } + } + } + } + } + + /** + * Reorders a star join based on heuristics: + * 1) Finds the star join with the largest fact table and places it on the driving + * arm of the left-deep tree. This plan avoids large table access on the inner, and + * thus favor hash joins. + * 2) Applies the most selective dimensions early in the plan to reduce the amount of + * data flow. + */ + def reorderStarJoins( + input: Seq[(LogicalPlan, InnerLike)], + conditions: Seq[Expression]): Seq[(LogicalPlan, InnerLike)] = { + assert(input.size >= 2) + + val emptyStarJoinPlan = Seq.empty[(LogicalPlan, InnerLike)] + + // Find the eligible star plans. Currently, it only returns + // the star join with the largest fact table. + val eligibleJoins = input.collect{ case (plan, Inner) => plan } + val starPlans = findStarJoins(eligibleJoins, conditions) + + if (starPlans.isEmpty) { + emptyStarJoinPlan + } else { + val starPlan = starPlans.head + val (factTable, dimTables) = (starPlan.head, starPlan.tail) + + // Only consider selective joins. This case is detected by observing local predicates + // on the dimension tables. In a star schema relationship, the join between the fact and the + // dimension table is a FK-PK join. Heuristically, a selective dimension may reduce + // the result of a join. + // Also, conservatively assume that a fact table is joined with more than one dimension. + if (dimTables.size >= 2 && isSelectiveStarJoin(dimTables, conditions)) { + val reorderDimTables = dimTables.map { plan => + TableAccessCardinality(plan, getTableAccessCardinality(plan)) + }.sortBy(_.size).map { + case TableAccessCardinality(p1, _) => p1 + } + + val reorderStarPlan = factTable +: reorderDimTables + reorderStarPlan.map(plan => (plan, Inner)) + } else { + emptyStarJoinPlan + } + } + } + + /** + * Determines if a column referenced by a base table access is a primary key. + * A column is a PK if it is not nullable and has unique values. + * To determine if a column has unique values in the absence of informational + * RI constraints, the number of distinct values is compared to the total + * number of rows in the table. If their relative difference + * is within the expected limits (i.e. 2 * spark.sql.statistics.ndv.maxError based + * on TPCDS data results), the column is assumed to have unique values. 
+ */ + private def isUnique( + column: Attribute, + plan: LogicalPlan): Boolean = plan match { + case PhysicalOperation(_, _, t: LeafNode) => + val leafCol = findLeafNodeCol(column, plan) + leafCol match { + case Some(col) if t.outputSet.contains(col) => + val stats = t.stats(conf) + stats.rowCount match { + case Some(rowCount) if rowCount >= 0 => + if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { + val colStats = stats.attributeStats.get(col) + if (colStats.get.nullCount > 0) { + false + } else { + val distinctCount = colStats.get.distinctCount + val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d) + // ndvMaxErr adjusted based on TPCDS 1TB data results + relDiff <= conf.ndvMaxError * 2 + } + } else { + false + } + case None => false + } + case None => false + } + case _ => false + } + + /** + * Given a column over a base table access, it returns + * the leaf node column from which the input column is derived. + */ + @tailrec + private def findLeafNodeCol( + column: Attribute, + plan: LogicalPlan): Option[Attribute] = plan match { + case pl @ PhysicalOperation(_, _, _: LeafNode) => + pl match { + case t: LeafNode if t.outputSet.contains(column) => + Option(column) + case p: Project if p.outputSet.exists(_.semanticEquals(column)) => + val col = p.outputSet.find(_.semanticEquals(column)).get + findLeafNodeCol(col, p.child) + case f: Filter => + findLeafNodeCol(column, f.child) + case _ => None + } + case _ => None + } + + /** + * Checks if a column has statistics. + * The column is assumed to be over a base table access. + */ + private def hasStatistics( + column: Attribute, + plan: LogicalPlan): Boolean = plan match { + case PhysicalOperation(_, _, t: LeafNode) => + val leafCol = findLeafNodeCol(column, plan) + leafCol match { + case Some(col) if t.outputSet.contains(col) => + val stats = t.stats(conf) + stats.attributeStats.nonEmpty && stats.attributeStats.contains(col) + case None => false + } + case _ => false + } + + /** + * Returns the join predicates between two input plans. It only + * considers basic comparison operators. + */ + @inline + private def findJoinConditions( + plan1: LogicalPlan, + plan2: LogicalPlan, + conditions: Seq[Expression]): Seq[Expression] = { + val refs = plan1.outputSet ++ plan2.outputSet + conditions.filter { + case BinaryComparison(_, _) => true + case _ => false + }.filterNot(canEvaluate(_, plan1)) + .filterNot(canEvaluate(_, plan2)) + .filter(_.references.subsetOf(refs)) + } + + /** + * Checks if a star join is a selective join. A star join is assumed + * to be selective if there are local predicates on the dimension + * tables. + */ + private def isSelectiveStarJoin( + dimTables: Seq[LogicalPlan], + conditions: Seq[Expression]): Boolean = dimTables.exists { + case plan @ PhysicalOperation(_, p, _: LeafNode) => + // Checks if any condition applies to the dimension tables. + // Exclude the IsNotNull predicates until predicate selectivity is available. + // In most cases, this predicate is artificially introduced by the Optimizer + // to enforce nullability constraints. + val localPredicates = conditions.filterNot(_.isInstanceOf[IsNotNull]) + .exists(canEvaluate(_, plan)) + + // Checks if there are any predicates pushed down to the base table access. + val pushedDownPredicates = p.nonEmpty && !p.forall(_.isInstanceOf[IsNotNull]) + + localPredicates || pushedDownPredicates + case _ => false + } + + /** + * Helper case class to hold (plan, rowCount) pairs. 
+ */ + private case class TableAccessCardinality(plan: LogicalPlan, size: Option[BigInt]) + + /** + * Returns the cardinality of a base table access. A base table access represents + * a LeafNode, or Project or Filter operators above a LeafNode. + */ + private def getTableAccessCardinality( + input: LogicalPlan): Option[BigInt] = input match { + case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined => + if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) { + Option(input.stats(conf).rowCount.get) + } else { + Option(t.stats(conf).rowCount.get) + } + case _ => None + } +} /** * Reorder the joins and push all the conditions into join, so that the bottom ones have at least * one condition. * * The order of joins will not be changed if all of them already have at least one condition. + * + * If star schema detection is enabled, reorder the star join plans based on heuristics. */ -object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { - +case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { /** * Join a list of plans together and push down the conditions into them. * @@ -42,7 +370,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { * @param conditions a list of condition for join. */ @tailrec - def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) + final def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) : LogicalPlan = { assert(input.size >= 2) if (input.size == 2) { @@ -83,9 +411,19 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case j @ ExtractFiltersAndInnerJoins(input, conditions) + case ExtractFiltersAndInnerJoins(input, conditions) if input.size > 2 && conditions.nonEmpty => - createOrderedJoin(input, conditions) + if (conf.starSchemaDetection && !conf.cboEnabled) { + val starJoinPlan = StarSchemaDetection(conf).reorderStarJoins(input, conditions) + if (starJoinPlan.nonEmpty) { + val rest = input.filterNot(starJoinPlan.contains(_)) + createOrderedJoin(starJoinPlan ++ rest, conditions) + } else { + createOrderedJoin(input, conditions) + } + } else { + createOrderedJoin(input, conditions) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 0893af2673..d39b0ef7e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -167,8 +167,8 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match { case Join(left, right, joinType: InnerLike, cond) => val (plans, conditions) = flattenJoin(left, joinType) - (plans ++ Seq((right, joinType)), conditions ++ cond.toSeq) - + (plans ++ Seq((right, joinType)), conditions ++ + cond.toSeq.flatMap(splitConjunctivePredicates)) case Filter(filterCondition, j @ Join(left, right, _: InnerLike, joinCondition)) => val (plans, conditions) = flattenJoin(j) (plans, conditions ++ splitConjunctivePredicates(filterCondition)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d2ac4b88ee..b6e0b8ccbe 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -719,6 +719,18 @@ object SQLConf {
     .checkValue(weight => weight >= 0 && weight <= 1, "The weight value must be in [0, 1].")
     .createWithDefault(0.7)
 
+  val STARSCHEMA_DETECTION = buildConf("spark.sql.cbo.starSchemaDetection")
+    .doc("When true, it enables join reordering based on star schema detection.")
+    .booleanConf
+    .createWithDefault(false)
+
+  val STARSCHEMA_FACT_TABLE_RATIO = buildConf("spark.sql.cbo.starJoinFTRatio")
+    .internal()
+    .doc("Specifies the upper limit of the ratio between the sizes of the two largest " +
+      "tables for a star join to be considered.")
+    .doubleConf
+    .createWithDefault(0.9)
+
   val SESSION_LOCAL_TIMEZONE = buildConf("spark.sql.session.timeZone")
     .doc("""The ID of session local timezone, e.g. "GMT", "America/Los_Angeles", etc.""")
@@ -988,6 +1000,10 @@ class SQLConf extends Serializable with Logging {
   def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
 
+  def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION)
+
+  def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 985e49069d..61e8180814 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
 import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
 
 class JoinOptimizationSuite extends PlanTest {
@@ -38,7 +38,7 @@ class JoinOptimizationSuite extends PlanTest {
       CombineFilters,
       PushDownPredicate,
       BooleanSimplification,
-      ReorderJoin,
+      ReorderJoin(SimpleCatalystConf(true)),
       PushPredicateThroughJoin,
       ColumnPruning,
       CollapseProject) :: Nil
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
index 5607bcd16f..05b839b011 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -22,10 +22,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
 import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
-import org.apache.spark.sql.catalyst.util._
 
 class JoinReorderSuite
extends PlanTest with StatsEstimationTestBase { Batch("Operator Optimizations", FixedPoint(100), CombineFilters, PushDownPredicate, - ReorderJoin, + ReorderJoin(conf), PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: @@ -203,27 +202,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { originalPlan: LogicalPlan, groundTruthBestPlan: LogicalPlan): Unit = { val optimized = Optimize.execute(originalPlan.analyze) - val normalized1 = normalizePlan(normalizeExprIds(optimized)) - val normalized2 = normalizePlan(normalizeExprIds(groundTruthBestPlan.analyze)) - if (!sameJoinPlan(normalized1, normalized2)) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } - } - - /** Consider symmetry for joins when comparing plans. */ - private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - (plan1, plan2) match { - case (j1: Join, j2: Join) => - (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || - (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) - case _ if plan1.children.nonEmpty && plan2.children.nonEmpty => - (plan1.children, plan2.children).zipped.forall { case (c1, c2) => sameJoinPlan(c1, c2) } - case _ => - plan1 == plan2 - } + val expected = groundTruthBestPlan.analyze + compareJoinOrder(optimized, expected) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala new file mode 100644 index 0000000000..93fdd98d1a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -0,0 +1,580 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} + + +class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { + + override val conf = SimpleCatalystConf( + caseSensitiveAnalysis = true, starSchemaDetection = true) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushDownPredicate, + ReorderJoin(conf), + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: Nil + } + + // Table setup using star schema relationships: + // + // d1 - f1 - d2 + // | + // d3 - s3 + // + // Table f1 is the fact table. Tables d1, d2, and d3 are the dimension tables. + // Dimension d3 is further joined/normalized into table s3. + // Tables' cardinality: f1 > d3 > d1 > d2 > s3 + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + // F1 + attr("f1_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + // D1 + attr("d1_pk1") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + // D2 + attr("d2_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("d2_pk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c3") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + // D3 + attr("d3_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_pk1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + // S3 + attr("s3_pk1") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_c2") -> 
ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_c3") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + // F11 + attr("f11_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f11_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f11_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f11_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4) + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + private val f1 = StatsTestPlan( + outputList = Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c4").map(nameToAttr), + rowCount = 6, + size = Some(48), + attributeStats = AttributeMap(Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c4").map(nameToColInfo))) + + private val d1 = StatsTestPlan( + outputList = Seq("d1_pk1", "d1_c2", "d1_c3", "d1_c4").map(nameToAttr), + rowCount = 4, + size = Some(32), + attributeStats = AttributeMap(Seq("d1_pk1", "d1_c2", "d1_c3", "d1_c4").map(nameToColInfo))) + + private val d2 = StatsTestPlan( + outputList = Seq("d2_c2", "d2_pk1", "d2_c3", "d2_c4").map(nameToAttr), + rowCount = 3, + size = Some(24), + attributeStats = AttributeMap(Seq("d2_c2", "d2_pk1", "d2_c3", "d2_c4").map(nameToColInfo))) + + private val d3 = StatsTestPlan( + outputList = Seq("d3_fk1", "d3_c2", "d3_pk1", "d3_c4").map(nameToAttr), + rowCount = 5, + size = Some(40), + attributeStats = AttributeMap(Seq("d3_fk1", "d3_c2", "d3_pk1", "d3_c4").map(nameToColInfo))) + + private val s3 = StatsTestPlan( + outputList = Seq("s3_pk1", "s3_c2", "s3_c3", "s3_c4").map(nameToAttr), + rowCount = 2, + size = Some(17), + attributeStats = AttributeMap(Seq("s3_pk1", "s3_c2", "s3_c3", "s3_c4").map(nameToColInfo))) + + private val d3_ns = LocalRelation('d3_fk1.int, 'd3_c2.int, 'd3_pk1.int, 'd3_c4.int) + + private val f11 = StatsTestPlan( + outputList = Seq("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4").map(nameToAttr), + rowCount = 6, + size = Some(48), + attributeStats = AttributeMap(Seq("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4") + .map(nameToColInfo))) + + private val subq = d3.select(sum('d3_fk1).as('col)) + + test("Test 1: Selective star-join on all dimensions") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // s3 - d3 + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, d2, f1, d3, s3 + // where f1_fk2 = d2_pk1 and d2_c2 < 2 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d2, d1, d3, s3 + val query = + d1.join(d2).join(f1).join(d3).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === 
nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 2: Star join on a subset of dimensions due to inequality joins") { + // Star join: + // (=) (<) + // d1 - f1 - d2 + // | + // | (=) + // d3 - s3 + // (=) + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, f1, d2, s3, d3 + // where f1_fk2 < d2_pk1 + // and f1_fk1 = d1_pk1 and d1_c2 = 2 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Default join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d1, d3, d2,, d3 + + val query = + d1.join(f1).join(d2).join(s3).join(d3) + .where((nameToAttr("f1_fk2") < nameToAttr("d2_pk1")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("d1_c2") === 2) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner, + Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 3: Star join on a subset of dimensions since join column is not unique") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, f1, d2, s3, d3 + // where f1_fk2 = d2_c4 + // and f1_fk1 = d1_pk1 and d1_c2 = 2 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Default join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d1, d3, d2, d3 + val query = + d1.join(f1).join(d2).join(s3).join(d3) + .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("d1_c2") === 2) && + (nameToAttr("f1_fk2") === nameToAttr("d2_c4")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner, + Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("s3_c2"))) + + + assertEqualPlans(query, expected) + } + + test("Test 4: Star join on a subset of dimensions since join column is nullable") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // s3 - d3 + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, f1, d2, s3, d3 + // where f1_fk2 = d2_c2 + // and f1_fk1 = d1_pk1 and d1_c2 = 2 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Default join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d1, d3, d2, s3 + + val query = + d1.join(f1).join(d2).join(s3).join(d3) + .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("d1_c2") === 2) && + (nameToAttr("f1_fk2") === nameToAttr("d2_c2")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner, + Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c2"))) + .join(s3, Inner, 
Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 5: Table stats not available for some of the joined tables") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3_ns - s3 + // + // select f1_fk1, f1_fk3 + // from d3_ns, f1, d1, d2, s3 + // where f1_fk2 = d2_pk1 and d2_c2 = 2 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional join reordering: d3_ns, f1, d1, d2, s3 + // Star join reordering: empty + + val query = + d3_ns.join(f1).join(d1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val equivQuery = + d3_ns.join(f1, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, equivQuery) + } + + test("Test 6: Join with complex plans") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // (sub-query) + // + // select f1_fk1, f1_fk3 + // from (select sum(d3_fk1) as col from d3) subq, f1, d1, d2 + // where f1_fk2 = d2_pk1 and d2_c2 < 2 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = sq.col + // + // Positional join reordering: d3, f1, d1, d2 + // Star join reordering: empty + + val query = + subq.join(f1).join(d1).join(d2) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === "col".attr)) + + val expected = + d3.select('d3_fk1).select(sum('d3_fk1).as('col)) + .join(f1, Inner, Some(nameToAttr("f1_fk3") === "col".attr)) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 7: Comparable fact table sizes") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // f11 - s3 + // + // select f1.f1_fk1, f1.f1_fk3 + // from d1, f11, f1, d2, s3 + // where f1.f1_fk2 = d2_pk1 and d2_c2 = 2 + // and f1.f1_fk1 = d1_pk1 + // and f1.f1_fk3 = f11.f1_fk3 + // and f11.f1_fk1 = s3_pk1 + // + // Positional join reordering: d1, f1, f11, d2, s3 + // Star join reordering: empty + + val query = + d1.join(f11).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("f11_fk3")) && + (nameToAttr("f11_fk1") === nameToAttr("s3_pk1"))) + + val equivQuery = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(f11, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("f11_fk3"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("f11_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, equivQuery) + } + + test("Test 8: No RI joins") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_c4 and d2_c2 = 2 + // and f1_fk1 = d1_c4 + 
// and f1_fk3 = d3_c4 + // and d3_fk1 = s3_pk1 + // + // Positional/default join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_c4")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_c4")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_c4")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_c4"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_c4"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_c4"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 9: Complex join predicates") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_pk1 and d2_c2 = 2 + // and abs(f1_fk1) = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional/default join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (abs(nameToAttr("f1_fk1")) === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(abs(nameToAttr("f1_fk1")) === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 10: Less than two dimensions") { + // Star join: + // (<) (=) + // d1 - f1 - d2 + // |(<) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_pk1 and d2_c2 = 2 + // and f1_fk1 < d1_pk1 + // and f1_fk3 < d3_pk1 + // + // Positional join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") < nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") < nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") < nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), + Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 11: Expanding star join") { + // Star join: + // (<) (<) + // d1 - f1 - d2 + // | (<) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 < d2_pk1 + // and f1_fk1 < d1_pk1 + // and f1_fk3 < d3_pk1 + // and d3_fk1 < s3_pk1 + // + // Positional join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") < nameToAttr("d2_pk1")) && + (nameToAttr("f1_fk1") < nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") < 
nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") < nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 12: Non selective star join") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_pk1 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = { + val optimized = Optimize.execute(plan1.analyze) + val expected = plan2.analyze + compareJoinOrder(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 5eb31413ad..2a9d057014 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -106,4 +106,30 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { protected def compareExpressions(e1: Expression, e2: Expression): Unit = { comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) } + + /** Fails the test if the join order in the two plans do not match */ + protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan) { + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) + if (!sameJoinPlan(normalized1, normalized2)) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** Consider symmetry for joins when comparing plans. */ + private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + (plan1, plan2) match { + case (j1: Join, j2: Join) => + (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || + (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) + case _ if plan1.children.nonEmpty && plan2.children.nonEmpty => + (plan1.children, plan2.children).zipped.forall { case (c1, c2) => sameJoinPlan(c1, c2) } + case _ => + plan1 == plan2 + } + } } -- cgit v1.2.3