author     Srinath Shankar <srinath@databricks.com>    2016-09-03 00:20:43 +0200
committer  Herman van Hovell <hvanhovell@databricks.com>    2016-09-03 00:20:43 +0200
commit     e6132a6cf10df8b12af8dd8d1a2c563792b5cc5a (patch)
tree       d706ac4d4091a7ae31eda5c7d62c2d8c2c4a7414 /sql/catalyst
parent     a2c9acb0e54b2e38cb8ee6431f1ea0e0b4cd959a (diff)
[SPARK-17298][SQL] Require explicit CROSS join for cartesian products
## What changes were proposed in this pull request?

Require the use of CROSS join syntax in SQL (and a new crossJoin DataFrame API) to specify explicit cartesian products between relations. By cartesian product we mean a join between relations R and S where there is no join condition involving columns from both R and S. If a cartesian product is detected in the absence of an explicit CROSS join, an error must be thrown. Turning on the "spark.sql.crossJoin.enabled" configuration flag will disable this check and allow cartesian products without an explicit CROSS join.

The new crossJoin DataFrame API must be used to specify explicit cross joins. The existing join(DataFrame) method will produce an INNER join that will require a subsequent join condition. That is, df1.join(df2) is equivalent to select * from df1, df2.

## How was this patch tested?

Added cross-join.sql to the SQLQueryTestSuite to test the check for cartesian products. Added a couple of tests to the DataFrameJoinSuite to test the crossJoin API. Modified various other test suites to explicitly specify a cross join where an INNER join or a comma-separated list was previously used.

Author: Srinath Shankar <srinath@databricks.com>

Closes #14866 from srinathshankar/crossjoin.
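As an illustration of the behavior above (a minimal sketch, not code from this patch; the session `spark`, the frames `df1`/`df2`, and the column `id` are assumed names):

    // Explicit cartesian product via the new API: always allowed.
    val product = df1.crossJoin(df2)

    // join(DataFrame) produces an INNER join; with no condition relating the
    // two sides, CheckCartesianProducts throws an AnalysisException once an
    // action forces the plan through the optimizer.
    val rejected = df1.join(df2)

    // An inner join whose condition references both sides is unaffected.
    val joined = df1.join(df2, df1("id") === df2("id"))

    // Escape hatch: disable the check for the whole session.
    spark.conf.set("spark.sql.crossJoin.enabled", "true")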
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4  |  3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala  |  7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala  |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala  |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala  |  49
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala  |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala  |  25
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala  |  1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala  |  27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala  |  20
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala  |  4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala  |  8
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala  |  4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala  |  60
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala  |  4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala  |  2
16 files changed, 169 insertions, 53 deletions
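Before the per-file diffs, a sketch of the SQL-level behavior the patch enforces, expressed through the SparkSession API (the relations R and S are assumed temp views; the R.r = S.s predicate echoes the example in the CheckCartesianProducts doc comment below):

    // Rejected: no join condition involves columns from both R and S,
    // so the optimizer detects a cartesian product and raises an error.
    spark.sql("SELECT * FROM R, S")
    spark.sql("SELECT * FROM R INNER JOIN S")

    // Allowed: ReorderJoin collects R.r = S.s as the join condition
    // before CheckCartesianProducts runs.
    spark.sql("SELECT * FROM R, S WHERE R.r = S.s")

    // Allowed: the cartesian product is explicitly requested.
    spark.sql("SELECT * FROM R CROSS JOIN S")

    // Alternatively, disable the check.
    spark.sql("SET spark.sql.crossJoin.enabled=true")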
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index a8af840c1e..0447436ea7 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -375,7 +375,7 @@ setQuantifier
relation
: left=relation
- ((CROSS | joinType) JOIN right=relation joinCriteria?
+ (joinType JOIN right=relation joinCriteria?
| NATURAL joinType JOIN right=relation
) #joinRelation
| relationPrimary #relationDefault
@@ -383,6 +383,7 @@ relation
joinType
: INNER?
+ | CROSS
| LEFT OUTER?
| LEFT SEMI
| RIGHT OUTER?
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index 4df100c2a8..75ae588c18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -36,6 +36,12 @@ trait CatalystConf {
def warehousePath: String
+ /** If true, cartesian products between relations will be allowed for all
+ * join types(inner, (left|right|full) outer).
+ * If false, cartesian products will require explicit CROSS JOIN syntax.
+ */
+ def crossJoinEnabled: Boolean
+
/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
* identifiers are equal.
@@ -55,5 +61,6 @@ case class SimpleCatalystConf(
optimizerInSetConversionThreshold: Int = 10,
maxCaseBranchesForCodegen: Int = 20,
runSQLonFile: Boolean = true,
+ crossJoinEnabled: Boolean = false,
warehousePath: String = "/user/hive/warehouse")
extends CatalystConf
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e559f235c5..18f814d6cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1003,7 +1003,7 @@ class Analyzer(
failOnOuterReference(j)
failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN")
j
- case j @ Join(_, right, jt, _) if jt != Inner =>
+ case j @ Join(_, right, jt, _) if !jt.isInstanceOf[InnerLike] =>
failOnOuterReference(j)
failOnOuterReferenceInSubTree(right, "a LEFT (OUTER) JOIN")
j
@@ -1899,7 +1899,7 @@ class Analyzer(
joinedCols ++
lUniqueOutput.map(_.withNullability(true)) ++
rUniqueOutput.map(_.withNullability(true))
- case Inner =>
+ case _ : InnerLike =>
leftKeys ++ lUniqueOutput ++ rUniqueOutput
case _ =>
sys.error("Unsupported natural join type " + joinType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index f6e32e29eb..e81370c504 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -94,7 +94,7 @@ object UnsupportedOperationChecker {
joinType match {
- case Inner =>
+ case _: InnerLike =>
if (left.isStreaming && right.isStreaming) {
throwError("Inner join between two streaming DataFrames/Datasets is not supported")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 7617d34261..d2f0c97989 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -22,6 +22,7 @@ import scala.collection.immutable.HashSet
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.api.java.function.FilterFunction
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
@@ -107,6 +108,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
RewriteCorrelatedScalarSubquery,
EliminateSerialization,
RemoveAliasOnlyProject) ::
+ Batch("Check Cartesian Products", Once,
+ CheckCartesianProducts(conf)) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
Batch("Typed Filter Optimization", fixedPoint,
@@ -838,7 +841,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
split(splitConjunctivePredicates(filterCondition), left, right)
joinType match {
- case Inner =>
+ case _: InnerLike =>
// push down the single side `where` condition into respective sides
val newLeft = leftFilterConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
@@ -848,7 +851,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e))
val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
- val join = Join(newLeft, newRight, Inner, newJoinCond)
+ val join = Join(newLeft, newRight, joinType, newJoinCond)
if (others.nonEmpty) {
Filter(others.reduceLeft(And), join)
} else {
@@ -885,7 +888,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
joinType match {
- case Inner | LeftExistence(_) =>
+ case _: InnerLike | LeftExistence(_) =>
// push down the single side only join filter for both sides sub queries
val newLeft = leftJoinConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
@@ -933,6 +936,46 @@ object CombineLimits extends Rule[LogicalPlan] {
}
/**
+ * Check if there any cartesian products between joins of any type in the optimized plan tree.
+ * Throw an error if a cartesian product is found without an explicit cross join specified.
+ * This rule is effectively disabled if the CROSS_JOINS_ENABLED flag is true.
+ *
+ * This rule must be run AFTER the ReorderJoin rule since the join conditions for each join must be
+ * collected before checking if it is a cartesian product. If you have
+ * SELECT * from R, S where R.r = S.s,
+ * the join between R and S is not a cartesian product and therefore should be allowed.
+ * The predicate R.r = S.s is not recognized as a join condition until the ReorderJoin rule.
+ */
+case class CheckCartesianProducts(conf: CatalystConf)
+ extends Rule[LogicalPlan] with PredicateHelper {
+ /**
+ * Check if a join is a cartesian product. Returns true if
+ * there are no join conditions involving references from both left and right.
+ */
+ def isCartesianProduct(join: Join): Boolean = {
+ val conditions = join.condition.map(splitConjunctivePredicates).getOrElse(Nil)
+ !conditions.map(_.references).exists(refs => refs.exists(join.left.outputSet.contains)
+ && refs.exists(join.right.outputSet.contains))
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan =
+ if (conf.crossJoinEnabled) {
+ plan
+ } else plan transform {
+ case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, condition)
+ if isCartesianProduct(j) =>
+ throw new AnalysisException(
+ s"""Detected cartesian product for ${j.joinType.sql} join between logical plans
+ |${left.treeString(false).trim}
+ |and
+ |${right.treeString(false).trim}
+ |Join condition is missing or trivial.
+ |Use the CROSS JOIN syntax to allow cartesian products between these relations."""
+ .stripMargin)
+ }
+}
+
+/**
* Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values.
*
* This uses the same rules for increasing the precision and scale of the output as
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index 50076b1a41..7400a01918 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -50,7 +50,7 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
empty(p)
case p @ Join(_, _, joinType, _) if p.children.exists(isEmptyLocalRelation) => joinType match {
- case Inner => empty(p)
+ case _: InnerLike => empty(p)
// Intersect is handled as LeftSemi by `ReplaceIntersectWithSemiJoin` rule.
// Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule.
case LeftOuter | LeftSemi | LeftAnti if isEmptyLocalRelation(p.left) => empty(p)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 158ad3d91f..1621bffd61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-
/**
* Reorder the joins and push all the conditions into join, so that the bottom ones have at least
* one condition.
@@ -39,39 +38,46 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
*
* The joined plan are picked from left to right, prefer those has at least one join condition.
*
- * @param input a list of LogicalPlans to join.
+ * @param input a list of LogicalPlans to inner join and the type of inner join.
* @param conditions a list of condition for join.
*/
@tailrec
- def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = {
+ def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression])
+ : LogicalPlan = {
assert(input.size >= 2)
if (input.size == 2) {
val (joinConditions, others) = conditions.partition(
e => !SubqueryExpression.hasCorrelatedSubquery(e))
- val join = Join(input(0), input(1), Inner, joinConditions.reduceLeftOption(And))
+ val ((left, leftJoinType), (right, rightJoinType)) = (input(0), input(1))
+ val innerJoinType = (leftJoinType, rightJoinType) match {
+ case (Inner, Inner) => Inner
+ case (_, _) => Cross
+ }
+ val join = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And))
if (others.nonEmpty) {
Filter(others.reduceLeft(And), join)
} else {
join
}
} else {
- val left :: rest = input.toList
+ val (left, _) :: rest = input.toList
// find out the first join that have at least one join condition
- val conditionalJoin = rest.find { plan =>
+ val conditionalJoin = rest.find { planJoinPair =>
+ val plan = planJoinPair._1
val refs = left.outputSet ++ plan.outputSet
conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan))
.exists(_.references.subsetOf(refs))
}
// pick the next one if no condition left
- val right = conditionalJoin.getOrElse(rest.head)
+ val (right, innerJoinType) = conditionalJoin.getOrElse(rest.head)
val joinedRefs = left.outputSet ++ right.outputSet
val (joinConditions, others) = conditions.partition(
e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e))
- val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And))
+ val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And))
// should not have reference to same logical plan
- createOrderedJoin(Seq(joined) ++ rest.filterNot(_ eq right), others)
+ createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others)
}
}
@@ -82,7 +88,6 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
}
}
-
/**
* Elimination of outer joins, if the predicates can restrict the result sets so that
* all null-supplying rows are eliminated
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 42fbc16d03..e4cb9f0161 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -539,6 +539,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = {
val baseJoinType = ctx.joinType match {
case null => Inner
+ case jt if jt.CROSS != null => Cross
case jt if jt.FULL != null => FullOuter
case jt if jt.SEMI != null => LeftSemi
case jt if jt.ANTI != null => LeftAnti
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 476c66af76..41cabb8cb3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -159,23 +159,30 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
*/
object ExtractFiltersAndInnerJoins extends PredicateHelper {
- // flatten all inner joins, which are next to each other
- def flattenJoin(plan: LogicalPlan): (Seq[LogicalPlan], Seq[Expression]) = plan match {
- case Join(left, right, Inner, cond) =>
- val (plans, conditions) = flattenJoin(left)
- (plans ++ Seq(right), conditions ++ cond.toSeq)
+ /**
+ * Flatten all inner joins, which are next to each other.
+ * Return a list of logical plans to be joined with a boolean for each plan indicating if it
+ * was involved in an explicit cross join. Also returns the entire list of join conditions for
+ * the left-deep tree.
+ */
+ def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner)
+ : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match {
+ case Join(left, right, joinType: InnerLike, cond) =>
+ val (plans, conditions) = flattenJoin(left, joinType)
+ (plans ++ Seq((right, joinType)), conditions ++ cond.toSeq)
- case Filter(filterCondition, j @ Join(left, right, Inner, joinCondition)) =>
+ case Filter(filterCondition, j @ Join(left, right, _: InnerLike, joinCondition)) =>
val (plans, conditions) = flattenJoin(j)
(plans, conditions ++ splitConjunctivePredicates(filterCondition))
- case _ => (Seq(plan), Seq())
+ case _ => (Seq((plan, parentJoinType)), Seq())
}
- def unapply(plan: LogicalPlan): Option[(Seq[LogicalPlan], Seq[Expression])] = plan match {
- case f @ Filter(filterCondition, j @ Join(_, _, Inner, _)) =>
+ def unapply(plan: LogicalPlan): Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]
+ = plan match {
+ case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _)) =>
Some(flattenJoin(f))
- case j @ Join(_, _, Inner, _) =>
+ case j @ Join(_, _, joinType, _) =>
Some(flattenJoin(j))
case _ => None
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 80674d9b4b..61e083e6fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -28,6 +28,7 @@ object JoinType {
case "rightouter" | "right" => RightOuter
case "leftsemi" => LeftSemi
case "leftanti" => LeftAnti
+ case "cross" => Cross
case _ =>
val supported = Seq(
"inner",
@@ -35,7 +36,8 @@ object JoinType {
"leftouter", "left",
"rightouter", "right",
"leftsemi",
- "leftanti")
+ "leftanti",
+ "cross")
throw new IllegalArgumentException(s"Unsupported join type '$typ'. " +
"Supported join types include: " + supported.mkString("'", "', '", "'") + ".")
@@ -46,10 +48,24 @@ sealed abstract class JoinType {
def sql: String
}
-case object Inner extends JoinType {
+/**
+ * The explicitCartesian flag indicates if the inner join was constructed with a CROSS join
+ * indicating a cartesian product has been explicitly requested.
+ */
+sealed abstract class InnerLike extends JoinType {
+ def explicitCartesian: Boolean
+}
+
+case object Inner extends InnerLike {
+ override def explicitCartesian: Boolean = false
override def sql: String = "INNER"
}
+case object Cross extends InnerLike {
+ override def explicitCartesian: Boolean = true
+ override def sql: String = "CROSS"
+}
+
case object LeftOuter extends JoinType {
override def sql: String = "LEFT OUTER"
}
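A sketch of the match shape this hierarchy enables across the rules touched above, where `planInnerJoin`/`planOuterJoin` are stand-ins for rule-specific handling (not code from the patch):

    joinType match {
      case _: InnerLike =>
        // Inner and Cross are planned identically; only the
        // explicitCartesian flag distinguishes them for the check.
        planInnerJoin()
      case LeftOuter | RightOuter | FullOuter =>
        planOuterJoin()
    }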
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 010aec7ba1..d2d33e40a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -293,7 +293,7 @@ case class Join(
override protected def validConstraints: Set[Expression] = {
joinType match {
- case Inner if condition.isDefined =>
+ case _: InnerLike if condition.isDefined =>
left.constraints
.union(right.constraints)
.union(splitConjunctivePredicates(condition.get).toSet)
@@ -302,7 +302,7 @@ case class Join(
.union(splitConjunctivePredicates(condition.get).toSet)
case j: ExistenceJoin =>
left.constraints
- case Inner =>
+ case _: InnerLike =>
left.constraints.union(right.constraints)
case LeftExistence(_) =>
left.constraints
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 13bf034f83..e7c8615bc5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max}
-import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
@@ -396,7 +396,7 @@ class AnalysisErrorSuite extends AnalysisTest {
}
test("error test for self-join") {
- val join = Join(testRelation, testRelation, Inner, None)
+ val join = Join(testRelation, testRelation, Cross, None)
val error = intercept[AnalysisException] {
SimpleAnalyzer.checkAnalysis(join)
}
@@ -475,7 +475,7 @@ class AnalysisErrorSuite extends AnalysisTest {
LocalRelation(
AttributeReference("c", BinaryType)(exprId = ExprId(4)),
AttributeReference("d", IntegerType)(exprId = ExprId(3))),
- Inner,
+ Cross,
Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)),
AttributeReference("c", BinaryType)(exprId = ExprId(4)))))
@@ -489,7 +489,7 @@ class AnalysisErrorSuite extends AnalysisTest {
LocalRelation(
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)),
AttributeReference("d", IntegerType)(exprId = ExprId(3))),
- Inner,
+ Cross,
Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)))))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 8971edc7d3..50ebad25cd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
@@ -341,7 +341,7 @@ class AnalysisSuite extends AnalysisTest {
Join(
Project(Seq($"x.key"), SubqueryAlias("x", input, None)),
Project(Seq($"y.key"), SubqueryAlias("y", input, None)),
- Inner, None))
+ Cross, None))
assertAnalysisSuccess(query)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index dbb3e6a527..087718b3ec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
-import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
+import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -54,6 +54,18 @@ class JoinOptimizationSuite extends PlanTest {
val z = testRelation.subquery('z)
def testExtract(plan: LogicalPlan, expected: Option[(Seq[LogicalPlan], Seq[Expression])]) {
+ val expectedNoCross = expected map {
+ seq_pair => {
+ val plans = seq_pair._1
+ val noCartesian = plans map { plan => (plan, Inner) }
+ (noCartesian, seq_pair._2)
+ }
+ }
+ testExtractCheckCross(plan, expectedNoCross)
+ }
+
+ def testExtractCheckCross
+ (plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]) {
assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected)
}
@@ -70,6 +82,16 @@ class JoinOptimizationSuite extends PlanTest {
testExtract(x.join(y).join(x.join(z)), Some(Seq(x, y, x.join(z)), Seq()))
testExtract(x.join(y).join(x.join(z)).where("x.b".attr === "y.d".attr),
Some(Seq(x, y, x.join(z)), Seq("x.b".attr === "y.d".attr)))
+
+ testExtractCheckCross(x.join(y, Cross), Some(Seq((x, Cross), (y, Cross)), Seq()))
+ testExtractCheckCross(x.join(y, Cross).join(z, Cross),
+ Some(Seq((x, Cross), (y, Cross), (z, Cross)), Seq()))
+ testExtractCheckCross(x.join(y, Cross, Some("x.b".attr === "y.d".attr)).join(z, Cross),
+ Some(Seq((x, Cross), (y, Cross), (z, Cross)), Seq("x.b".attr === "y.d".attr)))
+ testExtractCheckCross(x.join(y, Inner, Some("x.b".attr === "y.d".attr)).join(z, Cross),
+ Some(Seq((x, Inner), (y, Inner), (z, Cross)), Seq("x.b".attr === "y.d".attr)))
+ testExtractCheckCross(x.join(y, Cross, Some("x.b".attr === "y.d".attr)).join(z, Inner),
+ Some(Seq((x, Cross), (y, Cross), (z, Inner)), Seq("x.b".attr === "y.d".attr)))
}
test("reorder inner joins") {
@@ -77,18 +99,28 @@ class JoinOptimizationSuite extends PlanTest {
val y = testRelation1.subquery('y)
val z = testRelation.subquery('z)
- val originalQuery = {
- x.join(y).join(z)
- .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr))
+ val queryAnswers = Seq(
+ (
+ x.join(y).join(z).where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)),
+ x.join(z, condition = Some("x.b".attr === "z.b".attr))
+ .join(y, condition = Some("y.d".attr === "z.a".attr))
+ ),
+ (
+ x.join(y, Cross).join(z, Cross)
+ .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)),
+ x.join(z, Cross, Some("x.b".attr === "z.b".attr))
+ .join(y, Cross, Some("y.d".attr === "z.a".attr))
+ ),
+ (
+ x.join(y, Inner).join(z, Cross).where("x.b".attr === "z.a".attr),
+ x.join(z, Cross, Some("x.b".attr === "z.a".attr)).join(y, Inner)
+ )
+ )
+
+ queryAnswers foreach { queryAnswerPair =>
+ val optimized = Optimize.execute(queryAnswerPair._1.analyze)
+ comparePlans(optimized, analysis.EliminateSubqueryAliases(queryAnswerPair._2.analyze))
}
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- x.join(z, condition = Some("x.b".attr === "z.b".attr))
- .join(y, condition = Some("y.d".attr === "z.a".attr))
- .analyze
-
- comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
}
test("broadcasthint sets relation statistics to smallest value") {
@@ -98,7 +130,7 @@ class JoinOptimizationSuite extends PlanTest {
Project(Seq($"x.key", $"y.key"),
Join(
SubqueryAlias("x", input, None),
- BroadcastHint(SubqueryAlias("y", input, None)), Inner, None)).analyze
+ BroadcastHint(SubqueryAlias("y", input, None)), Cross, None)).analyze
val optimized = Optimize.execute(query)
@@ -106,7 +138,7 @@ class JoinOptimizationSuite extends PlanTest {
Join(
Project(Seq($"x.key"), SubqueryAlias("x", input, None)),
BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input, None))),
- Inner, None).analyze
+ Cross, None).analyze
comparePlans(optimized, expected)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index c549832ef3..908dde7a66 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -67,6 +67,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
// Note that `None` is used to compare with OptimizeWithoutPropagateEmptyRelation.
val testcases = Seq(
(true, true, Inner, None),
+ (true, true, Cross, None),
(true, true, LeftOuter, None),
(true, true, RightOuter, None),
(true, true, FullOuter, None),
@@ -74,6 +75,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
(true, true, LeftSemi, None),
(true, false, Inner, Some(LocalRelation('a.int, 'b.int))),
+ (true, false, Cross, Some(LocalRelation('a.int, 'b.int))),
(true, false, LeftOuter, None),
(true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))),
(true, false, FullOuter, None),
@@ -81,6 +83,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
(true, false, LeftSemi, None),
(false, true, Inner, Some(LocalRelation('a.int, 'b.int))),
+ (false, true, Cross, Some(LocalRelation('a.int, 'b.int))),
(false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))),
(false, true, RightOuter, None),
(false, true, FullOuter, None),
@@ -88,6 +91,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
(false, true, LeftSemi, Some(LocalRelation('a.int))),
(false, false, Inner, Some(LocalRelation('a.int, 'b.int))),
+ (false, false, Cross, Some(LocalRelation('a.int, 'b.int))),
(false, false, LeftOuter, Some(LocalRelation('a.int, 'b.int))),
(false, false, RightOuter, Some(LocalRelation('a.int, 'b.int))),
(false, false, FullOuter, Some(LocalRelation('a.int, 'b.int))),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 2fcbfc7067..faaea17b64 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -346,7 +346,7 @@ class PlanParserSuite extends PlanTest {
def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = {
tests.foreach(_(sql, jt))
}
- test("cross join", Inner, Seq(testUnconditionalJoin))
+ test("cross join", Cross, Seq(testUnconditionalJoin))
test(",", Inner, Seq(testUnconditionalJoin))
test("join", Inner, testAll)
test("inner join", Inner, testAll)