diff options
author | Srinath Shankar <srinath@databricks.com> | 2016-09-03 00:20:43 +0200 |
---|---|---|
committer | Herman van Hovell <hvanhovell@databricks.com> | 2016-09-03 00:20:43 +0200 |
commit | e6132a6cf10df8b12af8dd8d1a2c563792b5cc5a (patch) | |
tree | d706ac4d4091a7ae31eda5c7d62c2d8c2c4a7414 /sql/catalyst/src/test | |
parent | a2c9acb0e54b2e38cb8ee6431f1ea0e0b4cd959a (diff) | |
download | spark-e6132a6cf10df8b12af8dd8d1a2c563792b5cc5a.tar.gz spark-e6132a6cf10df8b12af8dd8d1a2c563792b5cc5a.tar.bz2 spark-e6132a6cf10df8b12af8dd8d1a2c563792b5cc5a.zip |
[SPARK-17298][SQL] Require explicit CROSS join for cartesian products
## What changes were proposed in this pull request?
Require the use of CROSS join syntax in SQL (and a new crossJoin
DataFrame API) to specify explicit cartesian products between relations.
By cartesian product we mean a join between relations R and S where
there is no join condition involving columns from both R and S.
If a cartesian product is detected in the absence of an explicit CROSS
join, an error must be thrown. Turning on the
"spark.sql.crossJoin.enabled" configuration flag will disable this check
and allow cartesian products without an explicit CROSS join.
The new crossJoin DataFrame API must be used to specify explicit cross
joins. The existing join(DataFrame) method will produce a INNER join
that will require a subsequent join condition.
That is df1.join(df2) is equivalent to select * from df1, df2.
## How was this patch tested?
Added cross-join.sql to the SQLQueryTestSuite to test the check for cartesian products. Added a couple of tests to the DataFrameJoinSuite to test the crossJoin API. Modified various other test suites to explicitly specify a cross join where an INNER join or a comma-separated list was previously used.
Author: Srinath Shankar <srinath@databricks.com>
Closes #14866 from srinathshankar/crossjoin.
Diffstat (limited to 'sql/catalyst/src/test')
5 files changed, 57 insertions, 21 deletions
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 13bf034f83..e7c8615bc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max} -import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -396,7 +396,7 @@ class AnalysisErrorSuite extends AnalysisTest { } test("error test for self-join") { - val join = Join(testRelation, testRelation, Inner, None) + val join = Join(testRelation, testRelation, Cross, None) val error = intercept[AnalysisException] { SimpleAnalyzer.checkAnalysis(join) } @@ -475,7 +475,7 @@ class AnalysisErrorSuite extends AnalysisTest { LocalRelation( AttributeReference("c", BinaryType)(exprId = ExprId(4)), AttributeReference("d", IntegerType)(exprId = ExprId(3))), - Inner, + Cross, Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)), AttributeReference("c", BinaryType)(exprId = ExprId(4))))) @@ -489,7 +489,7 @@ class AnalysisErrorSuite extends AnalysisTest { LocalRelation( AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)), AttributeReference("d", IntegerType)(exprId = ExprId(3))), - Inner, + Cross, Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4))))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 8971edc7d3..50ebad25cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -341,7 +341,7 @@ class AnalysisSuite extends AnalysisTest { Join( Project(Seq($"x.key"), SubqueryAlias("x", input, None)), Project(Seq($"y.key"), SubqueryAlias("y", input, None)), - Inner, None)) + Cross, None)) assertAnalysisSuccess(query) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index dbb3e6a527..087718b3ec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins -import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -54,6 +54,18 @@ class JoinOptimizationSuite extends PlanTest { val z = testRelation.subquery('z) def testExtract(plan: LogicalPlan, expected: Option[(Seq[LogicalPlan], Seq[Expression])]) { + val expectedNoCross = expected map { + seq_pair => { + val plans = seq_pair._1 + val noCartesian = plans map { plan => (plan, Inner) } + (noCartesian, seq_pair._2) + } + } + testExtractCheckCross(plan, expectedNoCross) + } + + def testExtractCheckCross + (plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]) { assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected) } @@ -70,6 +82,16 @@ class JoinOptimizationSuite extends PlanTest { testExtract(x.join(y).join(x.join(z)), Some(Seq(x, y, x.join(z)), Seq())) testExtract(x.join(y).join(x.join(z)).where("x.b".attr === "y.d".attr), Some(Seq(x, y, x.join(z)), Seq("x.b".attr === "y.d".attr))) + + testExtractCheckCross(x.join(y, Cross), Some(Seq((x, Cross), (y, Cross)), Seq())) + testExtractCheckCross(x.join(y, Cross).join(z, Cross), + Some(Seq((x, Cross), (y, Cross), (z, Cross)), Seq())) + testExtractCheckCross(x.join(y, Cross, Some("x.b".attr === "y.d".attr)).join(z, Cross), + Some(Seq((x, Cross), (y, Cross), (z, Cross)), Seq("x.b".attr === "y.d".attr))) + testExtractCheckCross(x.join(y, Inner, Some("x.b".attr === "y.d".attr)).join(z, Cross), + Some(Seq((x, Inner), (y, Inner), (z, Cross)), Seq("x.b".attr === "y.d".attr))) + testExtractCheckCross(x.join(y, Cross, Some("x.b".attr === "y.d".attr)).join(z, Inner), + Some(Seq((x, Cross), (y, Cross), (z, Inner)), Seq("x.b".attr === "y.d".attr))) } test("reorder inner joins") { @@ -77,18 +99,28 @@ class JoinOptimizationSuite extends PlanTest { val y = testRelation1.subquery('y) val z = testRelation.subquery('z) - val originalQuery = { - x.join(y).join(z) - .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)) + val queryAnswers = Seq( + ( + x.join(y).join(z).where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)), + x.join(z, condition = Some("x.b".attr === "z.b".attr)) + .join(y, condition = Some("y.d".attr === "z.a".attr)) + ), + ( + x.join(y, Cross).join(z, Cross) + .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)), + x.join(z, Cross, Some("x.b".attr === "z.b".attr)) + .join(y, Cross, Some("y.d".attr === "z.a".attr)) + ), + ( + x.join(y, Inner).join(z, Cross).where("x.b".attr === "z.a".attr), + x.join(z, Cross, Some("x.b".attr === "z.a".attr)).join(y, Inner) + ) + ) + + queryAnswers foreach { queryAnswerPair => + val optimized = Optimize.execute(queryAnswerPair._1.analyze) + comparePlans(optimized, analysis.EliminateSubqueryAliases(queryAnswerPair._2.analyze)) } - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - x.join(z, condition = Some("x.b".attr === "z.b".attr)) - .join(y, condition = Some("y.d".attr === "z.a".attr)) - .analyze - - comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) } test("broadcasthint sets relation statistics to smallest value") { @@ -98,7 +130,7 @@ class JoinOptimizationSuite extends PlanTest { Project(Seq($"x.key", $"y.key"), Join( SubqueryAlias("x", input, None), - BroadcastHint(SubqueryAlias("y", input, None)), Inner, None)).analyze + BroadcastHint(SubqueryAlias("y", input, None)), Cross, None)).analyze val optimized = Optimize.execute(query) @@ -106,7 +138,7 @@ class JoinOptimizationSuite extends PlanTest { Join( Project(Seq($"x.key"), SubqueryAlias("x", input, None)), BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input, None))), - Inner, None).analyze + Cross, None).analyze comparePlans(optimized, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index c549832ef3..908dde7a66 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -67,6 +67,7 @@ class PropagateEmptyRelationSuite extends PlanTest { // Note that `None` is used to compare with OptimizeWithoutPropagateEmptyRelation. val testcases = Seq( (true, true, Inner, None), + (true, true, Cross, None), (true, true, LeftOuter, None), (true, true, RightOuter, None), (true, true, FullOuter, None), @@ -74,6 +75,7 @@ class PropagateEmptyRelationSuite extends PlanTest { (true, true, LeftSemi, None), (true, false, Inner, Some(LocalRelation('a.int, 'b.int))), + (true, false, Cross, Some(LocalRelation('a.int, 'b.int))), (true, false, LeftOuter, None), (true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))), (true, false, FullOuter, None), @@ -81,6 +83,7 @@ class PropagateEmptyRelationSuite extends PlanTest { (true, false, LeftSemi, None), (false, true, Inner, Some(LocalRelation('a.int, 'b.int))), + (false, true, Cross, Some(LocalRelation('a.int, 'b.int))), (false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))), (false, true, RightOuter, None), (false, true, FullOuter, None), @@ -88,6 +91,7 @@ class PropagateEmptyRelationSuite extends PlanTest { (false, true, LeftSemi, Some(LocalRelation('a.int))), (false, false, Inner, Some(LocalRelation('a.int, 'b.int))), + (false, false, Cross, Some(LocalRelation('a.int, 'b.int))), (false, false, LeftOuter, Some(LocalRelation('a.int, 'b.int))), (false, false, RightOuter, Some(LocalRelation('a.int, 'b.int))), (false, false, FullOuter, Some(LocalRelation('a.int, 'b.int))), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 2fcbfc7067..faaea17b64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -346,7 +346,7 @@ class PlanParserSuite extends PlanTest { def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { tests.foreach(_(sql, jt)) } - test("cross join", Inner, Seq(testUnconditionalJoin)) + test("cross join", Cross, Seq(testUnconditionalJoin)) test(",", Inner, Seq(testUnconditionalJoin)) test("join", Inner, testAll) test("inner join", Inner, testAll) |