aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorIoana Delaney <ioanamdelaney@gmail.com>2017-03-20 16:04:58 +0800
committerWenchen Fan <wenchen@databricks.com>2017-03-20 16:04:58 +0800
commit81639115947a13017d1637549a8f66ba599b27b8 (patch)
tree03261c3f1a576c1f17f594e1f69cd6fcb6f5ef7c /sql/catalyst
parentf14f81e900e2e6c216055799584148a2c944268d (diff)
downloadspark-81639115947a13017d1637549a8f66ba599b27b8.tar.gz
spark-81639115947a13017d1637549a8f66ba599b27b8.tar.bz2
spark-81639115947a13017d1637549a8f66ba599b27b8.zip
[SPARK-17791][SQL] Join reordering using star schema detection
## What changes were proposed in this pull request? Star schema consists of one or more fact tables referencing a number of dimension tables. In general, queries against star schema are expected to run fast because of the established RI constraints among the tables. This design proposes a join reordering based on natural, generally accepted heuristics for star schema queries: - Finds the star join with the largest fact table and places it on the driving arm of the left-deep join. This plan avoids large tables on the inner, and thus favors hash joins. - Applies the most selective dimensions early in the plan to reduce the amount of data flow. The design document was included in SPARK-17791. Link to the google doc: [StarSchemaDetection](https://docs.google.com/document/d/1UAfwbm_A6wo7goHlVZfYK99pqDMEZUumi7pubJXETEA/edit?usp=sharing) ## How was this patch tested? A new test suite StarJoinSuite.scala was implemented. Author: Ioana Delaney <ioanamdelaney@gmail.com> Closes #15363 from ioana-delaney/starJoinReord2.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala350
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala16
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala29
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala580
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala26
10 files changed, 978 insertions, 36 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala
index 0d4903e03b..ac97987c55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala
@@ -40,6 +40,7 @@ case class SimpleCatalystConf(
override val cboEnabled: Boolean = false,
override val joinReorderEnabled: Boolean = false,
override val joinReorderDPThreshold: Int = 12,
+ override val starSchemaDetection: Boolean = false,
override val warehousePath: String = "/user/hive/warehouse",
override val sessionLocalTimeZone: String = TimeZone.getDefault().getID,
override val maxNestedViewDepth: Int = 100)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index 1b32bda72b..521c468fe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -53,6 +53,8 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = {
val (items, conditions) = extractInnerJoins(plan)
+ // TODO: Compute the set of star-joins and use them in the join enumeration
+ // algorithm to prune un-optimal plan choices.
val result =
// Do reordering if the number of items is appropriate and join conditions exist.
// We also need to check if costs of all items can be evaluated.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c8ed4190a1..d7524a57ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -82,7 +82,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
Batch("Operator Optimizations", fixedPoint,
// Operator push down
PushProjectionThroughUnion,
- ReorderJoin,
+ ReorderJoin(conf),
EliminateOuterJoin,
PushPredicateThroughJoin,
PushDownPredicate,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index bfe529e21e..58e4a230f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -20,19 +20,347 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
+import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, PhysicalOperation}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Encapsulates star-schema join detection.
+ */
+case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
+
+ /**
+ * Star schema consists of one or more fact tables referencing a number of dimension
+ * tables. In general, star-schema joins are detected using the following conditions:
+ * 1. Informational RI constraints (reliable detection)
+ * + Dimension contains a primary key that is being joined to the fact table.
+ * + Fact table contains foreign keys referencing multiple dimension tables.
+ * 2. Cardinality based heuristics
+ * + Usually, the table with the highest cardinality is the fact table.
+ * + Table being joined with the most number of tables is the fact table.
+ *
+ * To detect star joins, the algorithm uses a combination of the above two conditions.
+ * The fact table is chosen based on the cardinality heuristics, and the dimension
+ * tables are chosen based on the RI constraints. A star join will consist of the largest
+ * fact table joined with the dimension tables on their primary keys. To detect that a
+ * column is a primary key, the algorithm uses table and column statistics.
+ *
+ * Since Catalyst only supports left-deep tree plans, the algorithm currently returns only
+ * the star join with the largest fact table. Choosing the largest fact table on the
+ * driving arm to avoid large inners is in general a good heuristic. This restriction can
+ * be lifted with support for bushy tree plans.
+ *
+ * The highlights of the algorithm are the following:
+ *
+ * Given a set of joined tables/plans, the algorithm first verifies if they are eligible
+ * for star join detection. An eligible plan is a base table access with valid statistics.
+ * A base table access represents Project or Filter operators above a LeafNode. Conservatively,
+ * the algorithm only considers base table access as part of a star join since they provide
+ * reliable statistics.
+ *
+ * If some of the plans are not base table access, or statistics are not available, the algorithm
+ * returns an empty star join plan since, in the absence of statistics, it cannot make
+ * good planning decisions. Otherwise, the algorithm finds the table with the largest cardinality
+ * (number of rows), which is assumed to be a fact table.
+ *
+ * Next, it computes the set of dimension tables for the current fact table. A dimension table
+ * is assumed to be in a RI relationship with a fact table. To infer column uniqueness,
+ * the algorithm compares the number of distinct values with the total number of rows in the
+ * table. If their relative difference is within certain limits (i.e. ndvMaxError * 2, adjusted
+ * based on 1TB TPC-DS data), the column is assumed to be unique.
+ */
+ def findStarJoins(
+ input: Seq[LogicalPlan],
+ conditions: Seq[Expression]): Seq[Seq[LogicalPlan]] = {
+
+ val emptyStarJoinPlan = Seq.empty[Seq[LogicalPlan]]
+
+ if (!conf.starSchemaDetection || input.size < 2) {
+ emptyStarJoinPlan
+ } else {
+ // Find if the input plans are eligible for star join detection.
+ // An eligible plan is a base table access with valid statistics.
+ val foundEligibleJoin = input.forall {
+ case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true
+ case _ => false
+ }
+
+ if (!foundEligibleJoin) {
+ // Some plans don't have stats or are complex plans. Conservatively,
+ // return an empty star join. This restriction can be lifted
+ // once statistics are propagated in the plan.
+ emptyStarJoinPlan
+ } else {
+ // Find the fact table using cardinality based heuristics i.e.
+ // the table with the largest number of rows.
+ val sortedFactTables = input.map { plan =>
+ TableAccessCardinality(plan, getTableAccessCardinality(plan))
+ }.collect { case t @ TableAccessCardinality(_, Some(_)) =>
+ t
+ }.sortBy(_.size)(implicitly[Ordering[Option[BigInt]]].reverse)
+
+ sortedFactTables match {
+ case Nil =>
+ emptyStarJoinPlan
+ case table1 :: table2 :: _
+ if table2.size.get.toDouble > conf.starSchemaFTRatio * table1.size.get.toDouble =>
+ // If the top largest tables have comparable number of rows, return an empty star plan.
+ // This restriction will be lifted when the algorithm is generalized
+ // to return multiple star plans.
+ emptyStarJoinPlan
+ case TableAccessCardinality(factTable, _) :: rest =>
+ // Find the fact table joins.
+ val allFactJoins = rest.collect { case TableAccessCardinality(plan, _)
+ if findJoinConditions(factTable, plan, conditions).nonEmpty =>
+ plan
+ }
+
+ // Find the corresponding join conditions.
+ val allFactJoinCond = allFactJoins.flatMap { plan =>
+ val joinCond = findJoinConditions(factTable, plan, conditions)
+ joinCond
+ }
+
+ // Verify if the join columns have valid statistics.
+ // Allow any relational comparison between the tables. Later
+ // we will heuristically choose a subset of equi-join
+ // tables.
+ val areStatsAvailable = allFactJoins.forall { dimTable =>
+ allFactJoinCond.exists {
+ case BinaryComparison(lhs: AttributeReference, rhs: AttributeReference) =>
+ val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs
+ val factCol = if (factTable.outputSet.contains(lhs)) lhs else rhs
+ hasStatistics(dimCol, dimTable) && hasStatistics(factCol, factTable)
+ case _ => false
+ }
+ }
+
+ if (!areStatsAvailable) {
+ emptyStarJoinPlan
+ } else {
+ // Find the subset of dimension tables. A dimension table is assumed to be in a
+ // RI relationship with the fact table. Only consider equi-joins
+ // between a fact and a dimension table to avoid expanding joins.
+ val eligibleDimPlans = allFactJoins.filter { dimTable =>
+ allFactJoinCond.exists {
+ case cond @ Equality(lhs: AttributeReference, rhs: AttributeReference) =>
+ val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs
+ isUnique(dimCol, dimTable)
+ case _ => false
+ }
+ }
+
+ if (eligibleDimPlans.isEmpty) {
+ // An eligible star join was not found because the join is not
+ // an RI join, or the star join is an expanding join.
+ emptyStarJoinPlan
+ } else {
+ Seq(factTable +: eligibleDimPlans)
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Reorders a star join based on heuristics:
+ * 1) Finds the star join with the largest fact table and places it on the driving
+ * arm of the left-deep tree. This plan avoids large table access on the inner, and
+ * thus favor hash joins.
+ * 2) Applies the most selective dimensions early in the plan to reduce the amount of
+ * data flow.
+ */
+ def reorderStarJoins(
+ input: Seq[(LogicalPlan, InnerLike)],
+ conditions: Seq[Expression]): Seq[(LogicalPlan, InnerLike)] = {
+ assert(input.size >= 2)
+
+ val emptyStarJoinPlan = Seq.empty[(LogicalPlan, InnerLike)]
+
+ // Find the eligible star plans. Currently, it only returns
+ // the star join with the largest fact table.
+ val eligibleJoins = input.collect{ case (plan, Inner) => plan }
+ val starPlans = findStarJoins(eligibleJoins, conditions)
+
+ if (starPlans.isEmpty) {
+ emptyStarJoinPlan
+ } else {
+ val starPlan = starPlans.head
+ val (factTable, dimTables) = (starPlan.head, starPlan.tail)
+
+ // Only consider selective joins. This case is detected by observing local predicates
+ // on the dimension tables. In a star schema relationship, the join between the fact and the
+ // dimension table is a FK-PK join. Heuristically, a selective dimension may reduce
+ // the result of a join.
+ // Also, conservatively assume that a fact table is joined with more than one dimension.
+ if (dimTables.size >= 2 && isSelectiveStarJoin(dimTables, conditions)) {
+ val reorderDimTables = dimTables.map { plan =>
+ TableAccessCardinality(plan, getTableAccessCardinality(plan))
+ }.sortBy(_.size).map {
+ case TableAccessCardinality(p1, _) => p1
+ }
+
+ val reorderStarPlan = factTable +: reorderDimTables
+ reorderStarPlan.map(plan => (plan, Inner))
+ } else {
+ emptyStarJoinPlan
+ }
+ }
+ }
+
+ /**
+ * Determines if a column referenced by a base table access is a primary key.
+ * A column is a PK if it is not nullable and has unique values.
+ * To determine if a column has unique values in the absence of informational
+ * RI constraints, the number of distinct values is compared to the total
+ * number of rows in the table. If their relative difference
+ * is within the expected limits (i.e. 2 * spark.sql.statistics.ndv.maxError based
+ * on TPCDS data results), the column is assumed to have unique values.
+ */
+ private def isUnique(
+ column: Attribute,
+ plan: LogicalPlan): Boolean = plan match {
+ case PhysicalOperation(_, _, t: LeafNode) =>
+ val leafCol = findLeafNodeCol(column, plan)
+ leafCol match {
+ case Some(col) if t.outputSet.contains(col) =>
+ val stats = t.stats(conf)
+ stats.rowCount match {
+ case Some(rowCount) if rowCount >= 0 =>
+ if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) {
+ val colStats = stats.attributeStats.get(col)
+ if (colStats.get.nullCount > 0) {
+ false
+ } else {
+ val distinctCount = colStats.get.distinctCount
+ val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d)
+ // ndvMaxErr adjusted based on TPCDS 1TB data results
+ relDiff <= conf.ndvMaxError * 2
+ }
+ } else {
+ false
+ }
+ case None => false
+ }
+ case None => false
+ }
+ case _ => false
+ }
+
+ /**
+ * Given a column over a base table access, it returns
+ * the leaf node column from which the input column is derived.
+ */
+ @tailrec
+ private def findLeafNodeCol(
+ column: Attribute,
+ plan: LogicalPlan): Option[Attribute] = plan match {
+ case pl @ PhysicalOperation(_, _, _: LeafNode) =>
+ pl match {
+ case t: LeafNode if t.outputSet.contains(column) =>
+ Option(column)
+ case p: Project if p.outputSet.exists(_.semanticEquals(column)) =>
+ val col = p.outputSet.find(_.semanticEquals(column)).get
+ findLeafNodeCol(col, p.child)
+ case f: Filter =>
+ findLeafNodeCol(column, f.child)
+ case _ => None
+ }
+ case _ => None
+ }
+
+ /**
+ * Checks if a column has statistics.
+ * The column is assumed to be over a base table access.
+ */
+ private def hasStatistics(
+ column: Attribute,
+ plan: LogicalPlan): Boolean = plan match {
+ case PhysicalOperation(_, _, t: LeafNode) =>
+ val leafCol = findLeafNodeCol(column, plan)
+ leafCol match {
+ case Some(col) if t.outputSet.contains(col) =>
+ val stats = t.stats(conf)
+ stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)
+ case None => false
+ }
+ case _ => false
+ }
+
+ /**
+ * Returns the join predicates between two input plans. It only
+ * considers basic comparison operators.
+ */
+ @inline
+ private def findJoinConditions(
+ plan1: LogicalPlan,
+ plan2: LogicalPlan,
+ conditions: Seq[Expression]): Seq[Expression] = {
+ val refs = plan1.outputSet ++ plan2.outputSet
+ conditions.filter {
+ case BinaryComparison(_, _) => true
+ case _ => false
+ }.filterNot(canEvaluate(_, plan1))
+ .filterNot(canEvaluate(_, plan2))
+ .filter(_.references.subsetOf(refs))
+ }
+
+ /**
+ * Checks if a star join is a selective join. A star join is assumed
+ * to be selective if there are local predicates on the dimension
+ * tables.
+ */
+ private def isSelectiveStarJoin(
+ dimTables: Seq[LogicalPlan],
+ conditions: Seq[Expression]): Boolean = dimTables.exists {
+ case plan @ PhysicalOperation(_, p, _: LeafNode) =>
+ // Checks if any condition applies to the dimension tables.
+ // Exclude the IsNotNull predicates until predicate selectivity is available.
+ // In most cases, this predicate is artificially introduced by the Optimizer
+ // to enforce nullability constraints.
+ val localPredicates = conditions.filterNot(_.isInstanceOf[IsNotNull])
+ .exists(canEvaluate(_, plan))
+
+ // Checks if there are any predicates pushed down to the base table access.
+ val pushedDownPredicates = p.nonEmpty && !p.forall(_.isInstanceOf[IsNotNull])
+
+ localPredicates || pushedDownPredicates
+ case _ => false
+ }
+
+ /**
+ * Helper case class to hold (plan, rowCount) pairs.
+ */
+ private case class TableAccessCardinality(plan: LogicalPlan, size: Option[BigInt])
+
+ /**
+ * Returns the cardinality of a base table access. A base table access represents
+ * a LeafNode, or Project or Filter operators above a LeafNode.
+ */
+ private def getTableAccessCardinality(
+ input: LogicalPlan): Option[BigInt] = input match {
+ case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined =>
+ if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) {
+ Option(input.stats(conf).rowCount.get)
+ } else {
+ Option(t.stats(conf).rowCount.get)
+ }
+ case _ => None
+ }
+}
/**
* Reorder the joins and push all the conditions into join, so that the bottom ones have at least
* one condition.
*
* The order of joins will not be changed if all of them already have at least one condition.
+ *
+ * If star schema detection is enabled, reorder the star join plans based on heuristics.
*/
-object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
-
+case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {
/**
* Join a list of plans together and push down the conditions into them.
*
@@ -42,7 +370,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
* @param conditions a list of condition for join.
*/
@tailrec
- def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression])
+ final def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression])
: LogicalPlan = {
assert(input.size >= 2)
if (input.size == 2) {
@@ -83,9 +411,19 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case j @ ExtractFiltersAndInnerJoins(input, conditions)
+ case ExtractFiltersAndInnerJoins(input, conditions)
if input.size > 2 && conditions.nonEmpty =>
- createOrderedJoin(input, conditions)
+ if (conf.starSchemaDetection && !conf.cboEnabled) {
+ val starJoinPlan = StarSchemaDetection(conf).reorderStarJoins(input, conditions)
+ if (starJoinPlan.nonEmpty) {
+ val rest = input.filterNot(starJoinPlan.contains(_))
+ createOrderedJoin(starJoinPlan ++ rest, conditions)
+ } else {
+ createOrderedJoin(input, conditions)
+ }
+ } else {
+ createOrderedJoin(input, conditions)
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 0893af2673..d39b0ef7e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -167,8 +167,8 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
: (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match {
case Join(left, right, joinType: InnerLike, cond) =>
val (plans, conditions) = flattenJoin(left, joinType)
- (plans ++ Seq((right, joinType)), conditions ++ cond.toSeq)
-
+ (plans ++ Seq((right, joinType)), conditions ++
+ cond.toSeq.flatMap(splitConjunctivePredicates))
case Filter(filterCondition, j @ Join(left, right, _: InnerLike, joinCondition)) =>
val (plans, conditions) = flattenJoin(j)
(plans, conditions ++ splitConjunctivePredicates(filterCondition))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d2ac4b88ee..b6e0b8ccbe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -719,6 +719,18 @@ object SQLConf {
.checkValue(weight => weight >= 0 && weight <= 1, "The weight value must be in [0, 1].")
.createWithDefault(0.7)
+ val STARSCHEMA_DETECTION = buildConf("spark.sql.cbo.starSchemaDetection")
+ .doc("When true, it enables join reordering based on star schema detection. ")
+ .booleanConf
+ .createWithDefault(false)
+
+ val STARSCHEMA_FACT_TABLE_RATIO = buildConf("spark.sql.cbo.starJoinFTRatio")
+ .internal()
+ .doc("Specifies the upper limit of the ratio between the largest fact tables" +
+ " for a star join to be considered. ")
+ .doubleConf
+ .createWithDefault(0.9)
+
val SESSION_LOCAL_TIMEZONE =
buildConf("spark.sql.session.timeZone")
.doc("""The ID of session local timezone, e.g. "GMT", "America/Los_Angeles", etc.""")
@@ -988,6 +1000,10 @@ class SQLConf extends Serializable with Logging {
def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
+ def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION)
+
+ def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO)
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 985e49069d..61e8180814 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
class JoinOptimizationSuite extends PlanTest {
@@ -38,7 +38,7 @@ class JoinOptimizationSuite extends PlanTest {
CombineFilters,
PushDownPredicate,
BooleanSimplification,
- ReorderJoin,
+ ReorderJoin(SimpleCatalystConf(true)),
PushPredicateThroughJoin,
ColumnPruning,
CollapseProject) :: Nil
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
index 5607bcd16f..05b839b011 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -22,10 +22,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
-import org.apache.spark.sql.catalyst.util._
class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
@@ -38,7 +37,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
Batch("Operator Optimizations", FixedPoint(100),
CombineFilters,
PushDownPredicate,
- ReorderJoin,
+ ReorderJoin(conf),
PushPredicateThroughJoin,
ColumnPruning,
CollapseProject) ::
@@ -203,27 +202,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
originalPlan: LogicalPlan,
groundTruthBestPlan: LogicalPlan): Unit = {
val optimized = Optimize.execute(originalPlan.analyze)
- val normalized1 = normalizePlan(normalizeExprIds(optimized))
- val normalized2 = normalizePlan(normalizeExprIds(groundTruthBestPlan.analyze))
- if (!sameJoinPlan(normalized1, normalized2)) {
- fail(
- s"""
- |== FAIL: Plans do not match ===
- |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
- """.stripMargin)
- }
- }
-
- /** Consider symmetry for joins when comparing plans. */
- private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
- (plan1, plan2) match {
- case (j1: Join, j2: Join) =>
- (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
- (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
- case _ if plan1.children.nonEmpty && plan2.children.nonEmpty =>
- (plan1.children, plan2.children).zipped.forall { case (c1, c2) => sameJoinPlan(c1, c2) }
- case _ =>
- plan1 == plan2
- }
+ val expected = groundTruthBestPlan.analyze
+ compareJoinOrder(optimized, expected)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala
new file mode 100644
index 0000000000..93fdd98d1a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala
@@ -0,0 +1,580 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
+
+
+class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
+
+ override val conf = SimpleCatalystConf(
+ caseSensitiveAnalysis = true, starSchemaDetection = true)
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Operator Optimizations", FixedPoint(100),
+ CombineFilters,
+ PushDownPredicate,
+ ReorderJoin(conf),
+ PushPredicateThroughJoin,
+ ColumnPruning,
+ CollapseProject) :: Nil
+ }
+
+ // Table setup using star schema relationships:
+ //
+ // d1 - f1 - d2
+ // |
+ // d3 - s3
+ //
+ // Table f1 is the fact table. Tables d1, d2, and d3 are the dimension tables.
+ // Dimension d3 is further joined/normalized into table s3.
+ // Tables' cardinality: f1 > d3 > d1 > d2 > s3
+ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
+ // F1
+ attr("f1_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("f1_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("f1_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("f1_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ // D1
+ attr("d1_pk1") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d1_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d1_c3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d1_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ // D2
+ attr("d2_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
+ nullCount = 1, avgLen = 4, maxLen = 4),
+ attr("d2_pk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d2_c3") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d2_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ // D3
+ attr("d3_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d3_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d3_pk1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d3_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ // S3
+ attr("s3_pk1") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("s3_c2") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("s3_c3") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("s3_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ // F11
+ attr("f11_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("f11_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("f11_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("f11_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4),
+ nullCount = 0, avgLen = 4, maxLen = 4)
+ ))
+
+ private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
+ private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
+ columnInfo.map(kv => kv._1.name -> kv)
+
+ private val f1 = StatsTestPlan(
+ outputList = Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c4").map(nameToAttr),
+ rowCount = 6,
+ size = Some(48),
+ attributeStats = AttributeMap(Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c4").map(nameToColInfo)))
+
+ private val d1 = StatsTestPlan(
+ outputList = Seq("d1_pk1", "d1_c2", "d1_c3", "d1_c4").map(nameToAttr),
+ rowCount = 4,
+ size = Some(32),
+ attributeStats = AttributeMap(Seq("d1_pk1", "d1_c2", "d1_c3", "d1_c4").map(nameToColInfo)))
+
+ private val d2 = StatsTestPlan(
+ outputList = Seq("d2_c2", "d2_pk1", "d2_c3", "d2_c4").map(nameToAttr),
+ rowCount = 3,
+ size = Some(24),
+ attributeStats = AttributeMap(Seq("d2_c2", "d2_pk1", "d2_c3", "d2_c4").map(nameToColInfo)))
+
+ private val d3 = StatsTestPlan(
+ outputList = Seq("d3_fk1", "d3_c2", "d3_pk1", "d3_c4").map(nameToAttr),
+ rowCount = 5,
+ size = Some(40),
+ attributeStats = AttributeMap(Seq("d3_fk1", "d3_c2", "d3_pk1", "d3_c4").map(nameToColInfo)))
+
+ private val s3 = StatsTestPlan(
+ outputList = Seq("s3_pk1", "s3_c2", "s3_c3", "s3_c4").map(nameToAttr),
+ rowCount = 2,
+ size = Some(17),
+ attributeStats = AttributeMap(Seq("s3_pk1", "s3_c2", "s3_c3", "s3_c4").map(nameToColInfo)))
+
+ private val d3_ns = LocalRelation('d3_fk1.int, 'd3_c2.int, 'd3_pk1.int, 'd3_c4.int)
+
+ private val f11 = StatsTestPlan(
+ outputList = Seq("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4").map(nameToAttr),
+ rowCount = 6,
+ size = Some(48),
+ attributeStats = AttributeMap(Seq("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4")
+ .map(nameToColInfo)))
+
+ private val subq = d3.select(sum('d3_fk1).as('col))
+
+ test("Test 1: Selective star-join on all dimensions") {
+ // Star join:
+ // (=) (=)
+ // d1 - f1 - d2
+ // | (=)
+ // s3 - d3
+ //
+ // Query:
+ // select f1_fk1, f1_fk3
+ // from d1, d2, f1, d3, s3
+ // where f1_fk2 = d2_pk1 and d2_c2 < 2
+ // and f1_fk1 = d1_pk1
+ // and f1_fk3 = d3_pk1
+ // and d3_fk1 = s3_pk1
+ //
+ // Positional join reordering: d1, f1, d2, d3, s3
+ // Star join reordering: f1, d2, d1, d3, s3
+ val query =
+ d1.join(d2).join(f1).join(d3).join(s3)
+ .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) &&
+ (nameToAttr("d2_c2") === 2) &&
+ (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) &&
+ (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) &&
+ (nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ val expected =
+ f1.join(d2.where(nameToAttr("d2_c2") === 2), Inner,
+ Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
+ .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
+ .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
+ .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ assertEqualPlans(query, expected)
+ }
+
+ test("Test 2: Star join on a subset of dimensions due to inequality joins") {
+ // Star join:
+ // (=) (<)
+ // d1 - f1 - d2
+ // |
+ // | (=)
+ // d3 - s3
+ // (=)
+ //
+ // Query:
+ // select f1_fk1, f1_fk3
+ // from d1, f1, d2, s3, d3
+ // where f1_fk2 < d2_pk1
+ // and f1_fk1 = d1_pk1 and d1_c2 = 2
+ // and f1_fk3 = d3_pk1
+ // and d3_fk1 = s3_pk1
+ //
+ // Default join reordering: d1, f1, d2, d3, s3
+ // Star join reordering: f1, d1, d3, d2,, d3
+
+ val query =
+ d1.join(f1).join(d2).join(s3).join(d3)
+ .where((nameToAttr("f1_fk2") < nameToAttr("d2_pk1")) &&
+ (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) &&
+ (nameToAttr("d1_c2") === 2) &&
+ (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) &&
+ (nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ val expected =
+ f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner,
+ Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
+ .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
+ .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1")))
+ .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ assertEqualPlans(query, expected)
+ }
+
+ test("Test 3: Star join on a subset of dimensions since join column is not unique") {
+ // Star join:
+ // (=) (=)
+ // d1 - f1 - d2
+ // | (=)
+ // d3 - s3
+ //
+ // Query:
+ // select f1_fk1, f1_fk3
+ // from d1, f1, d2, s3, d3
+ // where f1_fk2 = d2_c4
+ // and f1_fk1 = d1_pk1 and d1_c2 = 2
+ // and f1_fk3 = d3_pk1
+ // and d3_fk1 = s3_pk1
+ //
+ // Default join reordering: d1, f1, d2, d3, s3
+ // Star join reordering: f1, d1, d3, d2, d3
+ val query =
+ d1.join(f1).join(d2).join(s3).join(d3)
+ .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) &&
+ (nameToAttr("d1_c2") === 2) &&
+ (nameToAttr("f1_fk2") === nameToAttr("d2_c4")) &&
+ (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) &&
+ (nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ val expected =
+ f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner,
+ Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
+ .join(d3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+ .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
+ .join(s3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("s3_c2")))
+
+
+ assertEqualPlans(query, expected)
+ }
+
+ test("Test 4: Star join on a subset of dimensions since join column is nullable") {
+ // Star join:
+ // (=) (=)
+ // d1 - f1 - d2
+ // | (=)
+ // s3 - d3
+ //
+ // Query:
+ // select f1_fk1, f1_fk3
+ // from d1, f1, d2, s3, d3
+ // where f1_fk2 = d2_c2
+ // and f1_fk1 = d1_pk1 and d1_c2 = 2
+ // and f1_fk3 = d3_pk1
+ // and d3_fk1 = s3_pk1
+ //
+ // Default join reordering: d1, f1, d2, d3, s3
+ // Star join reordering: f1, d1, d3, d2, s3
+
+ val query =
+ d1.join(f1).join(d2).join(s3).join(d3)
+ .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) &&
+ (nameToAttr("d1_c2") === 2) &&
+ (nameToAttr("f1_fk2") === nameToAttr("d2_c2")) &&
+ (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) &&
+ (nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ val expected =
+ f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner,
+ Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
+ .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
+ .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c2")))
+ .join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1")))
+
+ assertEqualPlans(query, expected)
+ }
+
+ test("Test 5: Table stats not available for some of the joined tables") {
+ // Star join:
+ // (=) (=)
+ // d1 - f1 - d2
+ // | (=)
+ // d3_ns - s3
+ //
+ // select f1_fk1, f1_fk3
+ // from d3_ns, f1, d1, d2, s3
+ // where f1_fk2 = d2_pk1 and d2_c2 = 2
+ // and f1_fk1 = d1_pk1
+ // and f1_fk3 = d3_pk1
+ // and d3_fk1 = s3_pk1
+ //
+ // Positional join reordering: d3_ns, f1, d1, d2, s3
+ // Star join reordering: empty
+
+ val query =
+ d3_ns.join(f1).join(d1).join(d2).join(s3)
+ .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) &&
+ (nameToAttr("d2_c2") === 2) &&
+ (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) &&
+ (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) &&
+ (nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ val equivQuery =
+ d3_ns.join(f1, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
+ .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
+ .join(d2.where(nameToAttr("d2_c2") === 2), Inner,
+ Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
+ .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ assertEqualPlans(query, equivQuery)
+ }
+
+ test("Test 6: Join with complex plans") {
+ // Star join:
+ // (=) (=)
+ // d1 - f1 - d2
+ // | (=)
+ // (sub-query)
+ //
+ // select f1_fk1, f1_fk3
+ // from (select sum(d3_fk1) as col from d3) subq, f1, d1, d2
+ // where f1_fk2 = d2_pk1 and d2_c2 < 2
+ // and f1_fk1 = d1_pk1
+ // and f1_fk3 = sq.col
+ //
+ // Positional join reordering: d3, f1, d1, d2
+ // Star join reordering: empty
+
+ val query =
+ subq.join(f1).join(d1).join(d2)
+ .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) &&
+ (nameToAttr("d2_c2") === 2) &&
+ (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) &&
+ (nameToAttr("f1_fk3") === "col".attr))
+
+ val expected =
+ d3.select('d3_fk1).select(sum('d3_fk1).as('col))
+ .join(f1, Inner, Some(nameToAttr("f1_fk3") === "col".attr))
+ .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
+ .join(d2.where(nameToAttr("d2_c2") === 2), Inner,
+ Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
+
+ assertEqualPlans(query, expected)
+ }
+
+ test("Test 7: Comparable fact table sizes") {
+ // Star join:
+ // (=) (=)
+ // d1 - f1 - d2
+ // | (=)
+ // f11 - s3
+ //
+ // select f1.f1_fk1, f1.f1_fk3
+ // from d1, f11, f1, d2, s3
+ // where f1.f1_fk2 = d2_pk1 and d2_c2 = 2
+ // and f1.f1_fk1 = d1_pk1
+ // and f1.f1_fk3 = f11.f1_fk3
+ // and f11.f1_fk1 = s3_pk1
+ //
+ // Positional join reordering: d1, f1, f11, d2, s3
+ // Star join reordering: empty
+
+ val query =
+ d1.join(f11).join(f1).join(d2).join(s3)
+ .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) &&
+ (nameToAttr("d2_c2") === 2) &&
+ (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) &&
+ (nameToAttr("f1_fk3") === nameToAttr("f11_fk3")) &&
+ (nameToAttr("f11_fk1") === nameToAttr("s3_pk1")))
+
+ val equivQuery =
+ d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
+ .join(f11, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("f11_fk3")))
+ .join(d2.where(nameToAttr("d2_c2") === 2), Inner,
+ Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
+ .join(s3, Inner, Some(nameToAttr("f11_fk1") === nameToAttr("s3_pk1")))
+
+ assertEqualPlans(query, equivQuery)
+ }
+
+ test("Test 8: No RI joins") {
+ // Star join:
+ // (=) (=)
+ // d1 - f1 - d2
+ // | (=)
+ // d3 - s3
+ //
+ // select f1_fk1, f1_fk3
+ // from d1, d3, f1, d2, s3
+ // where f1_fk2 = d2_c4 and d2_c2 = 2
+ // and f1_fk1 = d1_c4
+ // and f1_fk3 = d3_c4
+ // and d3_fk1 = s3_pk1
+ //
+ // Positional/default join reordering: d1, f1, d3, d2, s3
+ // Star join reordering: empty
+
+ val query =
+ d1.join(d3).join(f1).join(d2).join(s3)
+ .where((nameToAttr("f1_fk2") === nameToAttr("d2_c4")) &&
+ (nameToAttr("d2_c2") === 2) &&
+ (nameToAttr("f1_fk1") === nameToAttr("d1_c4")) &&
+ (nameToAttr("f1_fk3") === nameToAttr("d3_c4")) &&
+ (nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ val expected =
+ d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_c4")))
+ .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_c4")))
+ .join(d2.where(nameToAttr("d2_c2") === 2), Inner,
+ Some(nameToAttr("f1_fk2") === nameToAttr("d2_c4")))
+ .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ assertEqualPlans(query, expected)
+ }
+
+ test("Test 9: Complex join predicates") {
+ // Star join:
+ // (=) (=)
+ // d1 - f1 - d2
+ // | (=)
+ // d3 - s3
+ //
+ // select f1_fk1, f1_fk3
+ // from d1, d3, f1, d2, s3
+ // where f1_fk2 = d2_pk1 and d2_c2 = 2
+ // and abs(f1_fk1) = d1_pk1
+ // and f1_fk3 = d3_pk1
+ // and d3_fk1 = s3_pk1
+ //
+ // Positional/default join reordering: d1, f1, d3, d2, s3
+ // Star join reordering: empty
+
+ val query =
+ d1.join(d3).join(f1).join(d2).join(s3)
+ .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) &&
+ (nameToAttr("d2_c2") === 2) &&
+ (abs(nameToAttr("f1_fk1")) === nameToAttr("d1_pk1")) &&
+ (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) &&
+ (nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ val expected =
+ d1.join(f1, Inner, Some(abs(nameToAttr("f1_fk1")) === nameToAttr("d1_pk1")))
+ .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
+ .join(d2.where(nameToAttr("d2_c2") === 2), Inner,
+ Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
+ .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ assertEqualPlans(query, expected)
+ }
+
+ test("Test 10: Less than two dimensions") {
+ // Star join:
+ // (<) (=)
+ // d1 - f1 - d2
+ // |(<)
+ // d3 - s3
+ //
+ // select f1_fk1, f1_fk3
+ // from d1, d3, f1, d2, s3
+ // where f1_fk2 = d2_pk1 and d2_c2 = 2
+ // and f1_fk1 < d1_pk1
+ // and f1_fk3 < d3_pk1
+ //
+ // Positional join reordering: d1, f1, d3, d2, s3
+ // Star join reordering: empty
+
+ val query =
+ d1.join(d3).join(f1).join(d2).join(s3)
+ .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) &&
+ (nameToAttr("d2_c2") === 2) &&
+ (nameToAttr("f1_fk1") < nameToAttr("d1_pk1")) &&
+ (nameToAttr("f1_fk3") < nameToAttr("d3_pk1")) &&
+ (nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ val expected =
+ d1.join(f1, Inner, Some(nameToAttr("f1_fk1") < nameToAttr("d1_pk1")))
+ .join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1")))
+ .join(d2.where(nameToAttr("d2_c2") === 2),
+ Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
+ .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ assertEqualPlans(query, expected)
+ }
+
+ test("Test 11: Expanding star join") {
+ // Star join:
+ // (<) (<)
+ // d1 - f1 - d2
+ // | (<)
+ // d3 - s3
+ //
+ // select f1_fk1, f1_fk3
+ // from d1, d3, f1, d2, s3
+ // where f1_fk2 < d2_pk1
+ // and f1_fk1 < d1_pk1
+ // and f1_fk3 < d3_pk1
+ // and d3_fk1 < s3_pk1
+ //
+ // Positional join reordering: d1, f1, d3, d2, s3
+ // Star join reordering: empty
+
+ val query =
+ d1.join(d3).join(f1).join(d2).join(s3)
+ .where((nameToAttr("f1_fk2") < nameToAttr("d2_pk1")) &&
+ (nameToAttr("f1_fk1") < nameToAttr("d1_pk1")) &&
+ (nameToAttr("f1_fk3") < nameToAttr("d3_pk1")) &&
+ (nameToAttr("d3_fk1") < nameToAttr("s3_pk1")))
+
+ val expected =
+ d1.join(f1, Inner, Some(nameToAttr("f1_fk1") < nameToAttr("d1_pk1")))
+ .join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1")))
+ .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1")))
+ .join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1")))
+
+ assertEqualPlans(query, expected)
+ }
+
+ test("Test 12: Non selective star join") {
+ // Star join:
+ // (=) (=)
+ // d1 - f1 - d2
+ // | (=)
+ // d3 - s3
+ //
+ // select f1_fk1, f1_fk3
+ // from d1, d3, f1, d2, s3
+ // where f1_fk2 = d2_pk1
+ // and f1_fk1 = d1_pk1
+ // and f1_fk3 = d3_pk1
+ // and d3_fk1 = s3_pk1
+ //
+ // Positional join reordering: d1, f1, d3, d2, s3
+ // Star join reordering: empty
+
+ val query =
+ d1.join(d3).join(f1).join(d2).join(s3)
+ .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) &&
+ (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) &&
+ (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) &&
+ (nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ val expected =
+ d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
+ .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
+ .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
+ .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+
+ assertEqualPlans(query, expected)
+ }
+
+ private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
+ val optimized = Optimize.execute(plan1.analyze)
+ val expected = plan2.analyze
+ compareJoinOrder(optimized, expected)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 5eb31413ad..2a9d057014 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -106,4 +106,30 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
protected def compareExpressions(e1: Expression, e2: Expression): Unit = {
comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation))
}
+
+ /** Fails the test if the join order in the two plans do not match */
+ protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan) {
+ val normalized1 = normalizePlan(normalizeExprIds(plan1))
+ val normalized2 = normalizePlan(normalizeExprIds(plan2))
+ if (!sameJoinPlan(normalized1, normalized2)) {
+ fail(
+ s"""
+ |== FAIL: Plans do not match ===
+ |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
+ """.stripMargin)
+ }
+ }
+
+ /** Consider symmetry for joins when comparing plans. */
+ private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
+ (plan1, plan2) match {
+ case (j1: Join, j2: Join) =>
+ (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
+ (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
+ case _ if plan1.children.nonEmpty && plan2.children.nonEmpty =>
+ (plan1.children, plan2.children).zipped.forall { case (c1, c2) => sameJoinPlan(c1, c2) }
+ case _ =>
+ plan1 == plan2
+ }
+ }
}