path: root/sql/catalyst/src/test
authorSrinath Shankar <srinath@databricks.com>2016-09-03 00:20:43 +0200
committerHerman van Hovell <hvanhovell@databricks.com>2016-09-03 00:20:43 +0200
commite6132a6cf10df8b12af8dd8d1a2c563792b5cc5a (patch)
treed706ac4d4091a7ae31eda5c7d62c2d8c2c4a7414 /sql/catalyst/src/test
parenta2c9acb0e54b2e38cb8ee6431f1ea0e0b4cd959a (diff)
[SPARK-17298][SQL] Require explicit CROSS join for cartesian products
## What changes were proposed in this pull request?

Require the use of CROSS join syntax in SQL (and a new crossJoin DataFrame API) to specify explicit cartesian products between relations. By cartesian product we mean a join between relations R and S where there is no join condition involving columns from both R and S. If a cartesian product is detected in the absence of an explicit CROSS join, an error must be thrown. Turning on the "spark.sql.crossJoin.enabled" configuration flag will disable this check and allow cartesian products without an explicit CROSS join.

The new crossJoin DataFrame API must be used to specify explicit cross joins. The existing join(DataFrame) method will produce an INNER join that will require a subsequent join condition; that is, df1.join(df2) is equivalent to SELECT * FROM df1, df2.

## How was this patch tested?

Added cross-join.sql to the SQLQueryTestSuite to test the check for cartesian products. Added a couple of tests to DataFrameJoinSuite to test the crossJoin API. Modified various other test suites to explicitly specify a cross join where an INNER join or a comma-separated list was previously used.

Author: Srinath Shankar <srinath@databricks.com>

Closes #14866 from srinathshankar/crossjoin.
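For context before the test diffs, here is a minimal sketch of the behavior this patch enforces, written against the public DataFrame API. The object name, SparkSession setup, and sample data are illustrative only, not part of the patch:

```scala
import org.apache.spark.sql.SparkSession

// Illustrative sketch (not part of this patch): the new contract for
// cartesian products at the DataFrame level.
object CrossJoinDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("cross-join-demo")
      .getOrCreate()
    import spark.implicits._

    val df1 = Seq((1, "a"), (2, "b")).toDF("id", "v1")
    val df2 = Seq((1, "x"), (3, "y")).toDF("id", "v2")

    // Fine: the join condition references columns from both sides.
    df1.join(df2, df1("id") === df2("id")).show()

    // Fine: an explicit cartesian product via the new crossJoin API.
    df1.crossJoin(df2).show()

    // After this patch the following throws an AnalysisException:
    // df1.join(df2) is an INNER join with no condition, i.e. an
    // implicit cartesian product.
    // df1.join(df2).show()

    // Escape hatch described above: disable the check globally.
    // spark.conf.set("spark.sql.crossJoin.enabled", "true")

    spark.stop()
  }
}
```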
Diffstat (limited to 'sql/catalyst/src/test')
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 8
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 4
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala | 60
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala | 4
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala | 2
5 files changed, 57 insertions, 21 deletions
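Before the per-file hunks, it helps to pin down what the check tests. The definition above, a join whose condition involves columns from at most one of R and S, reads naturally as a predicate over Catalyst Join nodes. The sketch below is hypothetical: the rule this patch actually adds lives outside the test tree shown here and may differ in detail.

```scala
import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.catalyst.plans.logical.Join

// Hypothetical sketch of the cartesian-product test: a join is a cartesian
// product unless some conjunct of its condition references columns from
// both the left and the right child.
object CartesianCheck extends PredicateHelper {
  def isCartesianProduct(join: Join): Boolean = {
    val conjuncts = join.condition.map(splitConjunctivePredicates).getOrElse(Nil)
    !conjuncts.exists { e =>
      e.references.exists(join.left.outputSet.contains) &&
        e.references.exists(join.right.outputSet.contains)
    }
  }
}
```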
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 13bf034f83..e7c8615bc5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max}
-import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
@@ -396,7 +396,7 @@ class AnalysisErrorSuite extends AnalysisTest {
}
test("error test for self-join") {
- val join = Join(testRelation, testRelation, Inner, None)
+ val join = Join(testRelation, testRelation, Cross, None)
val error = intercept[AnalysisException] {
SimpleAnalyzer.checkAnalysis(join)
}
@@ -475,7 +475,7 @@ class AnalysisErrorSuite extends AnalysisTest {
LocalRelation(
AttributeReference("c", BinaryType)(exprId = ExprId(4)),
AttributeReference("d", IntegerType)(exprId = ExprId(3))),
- Inner,
+ Cross,
Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)),
AttributeReference("c", BinaryType)(exprId = ExprId(4)))))
@@ -489,7 +489,7 @@ class AnalysisErrorSuite extends AnalysisTest {
LocalRelation(
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)),
AttributeReference("d", IntegerType)(exprId = ExprId(3))),
- Inner,
+ Cross,
Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)))))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 8971edc7d3..50ebad25cd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
@@ -341,7 +341,7 @@ class AnalysisSuite extends AnalysisTest {
Join(
Project(Seq($"x.key"), SubqueryAlias("x", input, None)),
Project(Seq($"y.key"), SubqueryAlias("y", input, None)),
- Inner, None))
+ Cross, None))
assertAnalysisSuccess(query)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index dbb3e6a527..087718b3ec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
-import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
+import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -54,6 +54,18 @@ class JoinOptimizationSuite extends PlanTest {
val z = testRelation.subquery('z)
def testExtract(plan: LogicalPlan, expected: Option[(Seq[LogicalPlan], Seq[Expression])]) {
+ val expectedNoCross = expected map {
+ seq_pair => {
+ val plans = seq_pair._1
+ val noCartesian = plans map { plan => (plan, Inner) }
+ (noCartesian, seq_pair._2)
+ }
+ }
+ testExtractCheckCross(plan, expectedNoCross)
+ }
+
+ def testExtractCheckCross
+ (plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]) {
assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected)
}
@@ -70,6 +82,16 @@ class JoinOptimizationSuite extends PlanTest {
testExtract(x.join(y).join(x.join(z)), Some(Seq(x, y, x.join(z)), Seq()))
testExtract(x.join(y).join(x.join(z)).where("x.b".attr === "y.d".attr),
Some(Seq(x, y, x.join(z)), Seq("x.b".attr === "y.d".attr)))
+
+ testExtractCheckCross(x.join(y, Cross), Some(Seq((x, Cross), (y, Cross)), Seq()))
+ testExtractCheckCross(x.join(y, Cross).join(z, Cross),
+ Some(Seq((x, Cross), (y, Cross), (z, Cross)), Seq()))
+ testExtractCheckCross(x.join(y, Cross, Some("x.b".attr === "y.d".attr)).join(z, Cross),
+ Some(Seq((x, Cross), (y, Cross), (z, Cross)), Seq("x.b".attr === "y.d".attr)))
+ testExtractCheckCross(x.join(y, Inner, Some("x.b".attr === "y.d".attr)).join(z, Cross),
+ Some(Seq((x, Inner), (y, Inner), (z, Cross)), Seq("x.b".attr === "y.d".attr)))
+ testExtractCheckCross(x.join(y, Cross, Some("x.b".attr === "y.d".attr)).join(z, Inner),
+ Some(Seq((x, Cross), (y, Cross), (z, Inner)), Seq("x.b".attr === "y.d".attr)))
}
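The helper added above encodes a signature change in ExtractFiltersAndInnerJoins: each extracted base plan is now paired with the join type it appeared under (an InnerLike, i.e. Inner or Cross), so join reordering can reassemble the tree without silently turning a cross join into an inner one or vice versa. A REPL-style sketch of consuming the extractor under that signature (the describe helper is hypothetical):

```scala
import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Hypothetical helper: summarize what the post-patch extractor returns.
def describe(plan: LogicalPlan): String = plan match {
  case ExtractFiltersAndInnerJoins(input, conditions) =>
    // input: Seq[(LogicalPlan, InnerLike)] -- each base plan keeps the
    // Inner or Cross marker from the original tree; conditions holds the
    // extracted filter/join conjuncts.
    s"${input.size} base plans (${input.map(_._2).mkString(", ")}), " +
      s"${conditions.size} conjuncts"
  case _ => "not a tree of inner-like joins"
}
```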
test("reorder inner joins") {
@@ -77,18 +99,28 @@ class JoinOptimizationSuite extends PlanTest {
val y = testRelation1.subquery('y)
val z = testRelation.subquery('z)
- val originalQuery = {
- x.join(y).join(z)
- .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr))
+ val queryAnswers = Seq(
+ (
+ x.join(y).join(z).where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)),
+ x.join(z, condition = Some("x.b".attr === "z.b".attr))
+ .join(y, condition = Some("y.d".attr === "z.a".attr))
+ ),
+ (
+ x.join(y, Cross).join(z, Cross)
+ .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)),
+ x.join(z, Cross, Some("x.b".attr === "z.b".attr))
+ .join(y, Cross, Some("y.d".attr === "z.a".attr))
+ ),
+ (
+ x.join(y, Inner).join(z, Cross).where("x.b".attr === "z.a".attr),
+ x.join(z, Cross, Some("x.b".attr === "z.a".attr)).join(y, Inner)
+ )
+ )
+
+ queryAnswers foreach { queryAnswerPair =>
+ val optimized = Optimize.execute(queryAnswerPair._1.analyze)
+ comparePlans(optimized, analysis.EliminateSubqueryAliases(queryAnswerPair._2.analyze))
}
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- x.join(z, condition = Some("x.b".attr === "z.b".attr))
- .join(y, condition = Some("y.d".attr === "z.a".attr))
- .analyze
-
- comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
}
test("broadcasthint sets relation statistics to smallest value") {
@@ -98,7 +130,7 @@ class JoinOptimizationSuite extends PlanTest {
Project(Seq($"x.key", $"y.key"),
Join(
SubqueryAlias("x", input, None),
- BroadcastHint(SubqueryAlias("y", input, None)), Inner, None)).analyze
+ BroadcastHint(SubqueryAlias("y", input, None)), Cross, None)).analyze
val optimized = Optimize.execute(query)
@@ -106,7 +138,7 @@ class JoinOptimizationSuite extends PlanTest {
Join(
Project(Seq($"x.key"), SubqueryAlias("x", input, None)),
BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input, None))),
- Inner, None).analyze
+ Cross, None).analyze
comparePlans(optimized, expected)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index c549832ef3..908dde7a66 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -67,6 +67,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
// Note that `None` is used to compare with OptimizeWithoutPropagateEmptyRelation.
val testcases = Seq(
(true, true, Inner, None),
+ (true, true, Cross, None),
(true, true, LeftOuter, None),
(true, true, RightOuter, None),
(true, true, FullOuter, None),
@@ -74,6 +75,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
(true, true, LeftSemi, None),
(true, false, Inner, Some(LocalRelation('a.int, 'b.int))),
+ (true, false, Cross, Some(LocalRelation('a.int, 'b.int))),
(true, false, LeftOuter, None),
(true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))),
(true, false, FullOuter, None),
@@ -81,6 +83,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
(true, false, LeftSemi, None),
(false, true, Inner, Some(LocalRelation('a.int, 'b.int))),
+ (false, true, Cross, Some(LocalRelation('a.int, 'b.int))),
(false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))),
(false, true, RightOuter, None),
(false, true, FullOuter, None),
@@ -88,6 +91,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
(false, true, LeftSemi, Some(LocalRelation('a.int))),
(false, false, Inner, Some(LocalRelation('a.int, 'b.int))),
+ (false, false, Cross, Some(LocalRelation('a.int, 'b.int))),
(false, false, LeftOuter, Some(LocalRelation('a.int, 'b.int))),
(false, false, RightOuter, Some(LocalRelation('a.int, 'b.int))),
(false, false, FullOuter, Some(LocalRelation('a.int, 'b.int))),
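The new Cross rows in the table above mirror the Inner rows exactly: for empty-relation propagation, a cross join behaves like an inner join and collapses to an empty LocalRelation whenever either input is empty. A quick DataFrame-level sanity check of that semantics, assuming the SparkSession and implicits from the first sketch:

```scala
import spark.implicits._

val nonEmpty = Seq(1, 2, 3).toDF("a")
val empty = Seq.empty[Int].toDF("b")

// A cross join with an empty side is itself empty, exactly like an
// inner join with an empty side.
assert(nonEmpty.crossJoin(empty).count() == 0)
assert(empty.crossJoin(nonEmpty).count() == 0)
```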
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 2fcbfc7067..faaea17b64 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -346,7 +346,7 @@ class PlanParserSuite extends PlanTest {
def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = {
tests.foreach(_(sql, jt))
}
- test("cross join", Inner, Seq(testUnconditionalJoin))
+ test("cross join", Cross, Seq(testUnconditionalJoin))
test(",", Inner, Seq(testUnconditionalJoin))
test("join", Inner, testAll)
test("inner join", Inner, testAll)