path: root/sql/core/src
author     Srinath Shankar <srinath@databricks.com>        2016-09-03 00:20:43 +0200
committer  Herman van Hovell <hvanhovell@databricks.com>   2016-09-03 00:20:43 +0200
commit     e6132a6cf10df8b12af8dd8d1a2c563792b5cc5a (patch)
tree       d706ac4d4091a7ae31eda5c7d62c2d8c2c4a7414 /sql/core/src
parent     a2c9acb0e54b2e38cb8ee6431f1ea0e0b4cd959a (diff)
download   spark-e6132a6cf10df8b12af8dd8d1a2c563792b5cc5a.tar.gz
           spark-e6132a6cf10df8b12af8dd8d1a2c563792b5cc5a.tar.bz2
           spark-e6132a6cf10df8b12af8dd8d1a2c563792b5cc5a.zip
[SPARK-17298][SQL] Require explicit CROSS join for cartesian products
## What changes were proposed in this pull request?

Require the use of CROSS JOIN syntax in SQL (and a new crossJoin DataFrame API) to specify explicit cartesian products between relations. By cartesian product we mean a join between relations R and S where there is no join condition involving columns from both R and S. If a cartesian product is detected in the absence of an explicit CROSS JOIN, an error is thrown. Turning on the "spark.sql.crossJoin.enabled" configuration flag disables this check and allows cartesian products without an explicit CROSS JOIN.

The new crossJoin DataFrame API must be used to specify explicit cross joins. The existing join(DataFrame) method produces an INNER join that requires a subsequent join condition; that is, df1.join(df2) is equivalent to select * from df1, df2.

## How was this patch tested?

Added cross-join.sql to the SQLQueryTestSuite to test the check for cartesian products. Added a couple of tests to the DataFrameJoinSuite to test the crossJoin API. Modified various other test suites to explicitly specify a cross join where an INNER join or a comma-separated list was previously used.

Author: Srinath Shankar <srinath@databricks.com>

Closes #14866 from srinathshankar/crossjoin.
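A minimal usage sketch of the new SQL-side behavior (assuming a Spark shell session with the nt1 and nt2 temporary views created in the cross-join.sql test added below; the error text is the one asserted in JoinSuite):

    // Implicit cartesian product: no predicate relates nt1 and nt2, so executing
    // this now fails with "Detected cartesian product ..." unless
    // spark.sql.crossJoin.enabled is set to true.
    spark.sql("SELECT * FROM nt1 JOIN nt2").show()

    // Explicit CROSS JOIN marks the cartesian product as intentional and is always allowed.
    spark.sql("SELECT * FROM nt1 CROSS JOIN nt2").show()

    // A CROSS JOIN may still carry a predicate, as exercised in cross-join.sql.
    spark.sql("SELECT * FROM nt1 CROSS JOIN nt2 ON (nt1.k = nt2.k)").show()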
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 7
-rw-r--r--  sql/core/src/test/resources/sql-tests/inputs/cross-join.sql | 35
-rw-r--r--  sql/core/src/test/resources/sql-tests/inputs/cte.sql | 2
-rw-r--r--  sql/core/src/test/resources/sql-tests/inputs/outer-join.sql | 5
-rw-r--r--  sql/core/src/test/resources/sql-tests/results/cross-join.sql.out | 129
-rw-r--r--  sql/core/src/test/resources/sql-tests/results/cte.sql.out | 2
-rw-r--r--  sql/core/src/test/resources/sql-tests/results/outer-join.sql.out | 22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala | 19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 37
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 8
20 files changed, 288 insertions, 62 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index e7dcf0f51f..3b3cb82078 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -589,9 +589,9 @@ class Dataset[T] private[sql](
def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF())
/**
- * Cartesian join with another [[DataFrame]].
+ * Join with another [[DataFrame]].
*
- * Note that cartesian joins are very expensive without an extra filter that can be pushed down.
+ * Behaves as an INNER JOIN and requires a subsequent join predicate.
*
* @param right Right side of the join operation.
*
@@ -764,6 +764,20 @@ class Dataset[T] private[sql](
}
/**
+ * Explicit cartesian join with another [[DataFrame]].
+ *
+ * Note that cartesian joins are very expensive without an extra filter that can be pushed down.
+ *
+ * @param right Right side of the join operation.
+ *
+ * @group untypedrel
+ * @since 2.0.0
+ */
+ def crossJoin(right: Dataset[_]): DataFrame = withPlan {
+ Join(logicalPlan, right.logicalPlan, joinType = Cross, None)
+ }
+
+ /**
* :: Experimental ::
* Joins this Dataset returning a [[Tuple2]] for each pair where `condition` evaluates to
* true.
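A short sketch of the DataFrame/Dataset side of this change (hypothetical data; crossJoin is the method added above, and the "cross" join type string for joinWith mirrors its use in the DatasetSuite changes below):

    import spark.implicits._
    import org.apache.spark.sql.functions.lit

    val df1 = Seq((1, "1"), (3, "3")).toDF("int", "str")
    val df2 = Seq((2, "2"), (4, "4")).toDF("int", "str")

    df1.crossJoin(df2)                       // explicit cartesian product, always allowed
    df1.join(df2, df1("int") < df2("int"))   // join with a predicate, unchanged
    df1.join(df2).collect()                  // no join condition: fails at execution
                                             // unless spark.sql.crossJoin.enabled=true

    // Datasets: joinWith accepts "cross" as the join type string.
    val ds = Seq("1", "2").toDS()
    ds.joinWith(ds, lit(true), "cross").collect()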
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index b4899ad688..c389593b4f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -140,13 +140,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
private def canBuildRight(joinType: JoinType): Boolean = joinType match {
- case Inner | LeftOuter | LeftSemi | LeftAnti => true
+ case _: InnerLike | LeftOuter | LeftSemi | LeftAnti => true
case j: ExistenceJoin => true
case _ => false
}
private def canBuildLeft(joinType: JoinType): Boolean = joinType match {
- case Inner | RightOuter => true
+ case _: InnerLike | RightOuter => true
case _ => false
}
@@ -200,7 +200,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
planLater(left), planLater(right), BuildLeft, joinType, condition) :: Nil
// Pick CartesianProduct for InnerJoin
- case logical.Join(left, right, Inner, condition) =>
+ case logical.Join(left, right, _: InnerLike, condition) =>
joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
case logical.Join(left, right, joinType, condition) =>
@@ -212,8 +212,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
// This join could be very slow or OOM
joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), buildSide, joinType, condition,
- withinBroadcastThreshold = false) :: Nil
+ planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
// --- Cases where this strategy does not apply ---------------------------------------------
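The repeated change from "case Inner" to "case _: InnerLike" here and in the join operators below assumes a new Catalyst join-type hierarchy in which the Inner and Cross join types share a common parent, so strategies that treat them identically can match on that parent. A self-contained toy model of the idea (the real definitions live in catalyst's joinTypes.scala, not in this diff):

    // Toy model only: INNER and CROSS behave the same at execution time; they
    // differ only in whether a cartesian product was explicitly requested.
    sealed trait JoinType
    sealed trait InnerLike extends JoinType
    case object Inner extends InnerLike
    case object Cross extends InnerLike
    case object RightOuter extends JoinType

    def canBuildLeft(joinType: JoinType): Boolean = joinType match {
      case _: InnerLike | RightOuter => true   // now covers both INNER and CROSS
      case _ => false
    }

    assert(canBuildLeft(Inner) && canBuildLeft(Cross))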
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 0f24baacd1..0bc261d593 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -79,7 +79,7 @@ case class BroadcastHashJoinExec(
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
joinType match {
- case Inner => codegenInner(ctx, input)
+ case _: InnerLike => codegenInner(ctx, input)
case LeftOuter | RightOuter => codegenOuter(ctx, input)
case LeftSemi => codegenSemi(ctx, input)
case LeftAnti => codegenAnti(ctx, input)
@@ -134,7 +134,7 @@ case class BroadcastHashJoinExec(
ctx.INPUT_ROW = matched
buildPlan.output.zipWithIndex.map { case (a, i) =>
val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
- if (joinType == Inner) {
+ if (joinType.isInstanceOf[InnerLike]) {
ev
} else {
// the variables are needed even there is no matched rows
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 6a9965f1a2..43cdce7de8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -34,8 +34,7 @@ case class BroadcastNestedLoopJoinExec(
right: SparkPlan,
buildSide: BuildSide,
joinType: JoinType,
- condition: Option[Expression],
- withinBroadcastThreshold: Boolean = true) extends BinaryExecNode {
+ condition: Option[Expression]) extends BinaryExecNode {
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -65,7 +64,7 @@ case class BroadcastNestedLoopJoinExec(
override def output: Seq[Attribute] = {
joinType match {
- case Inner =>
+ case _: InnerLike =>
left.output ++ right.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
@@ -340,20 +339,11 @@ case class BroadcastNestedLoopJoinExec(
)
}
- protected override def doPrepare(): Unit = {
- if (!withinBroadcastThreshold && !sqlContext.conf.crossJoinEnabled) {
- throw new AnalysisException("Both sides of this join are outside the broadcasting " +
- "threshold and computing it could be prohibitively expensive. To explicitly enable it, " +
- s"please set ${SQLConf.CROSS_JOINS_ENABLED.key} = true")
- }
- super.doPrepare()
- }
-
protected override def doExecute(): RDD[InternalRow] = {
val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
val resultRdd = (joinType, buildSide) match {
- case (Inner, _) =>
+ case (_: InnerLike, _) =>
innerJoin(broadcastedRelation)
case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
outerJoin(broadcastedRelation)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 57866df90d..15dc9b4066 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -91,15 +91,6 @@ case class CartesianProductExec(
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
- protected override def doPrepare(): Unit = {
- if (!sqlContext.conf.crossJoinEnabled) {
- throw new AnalysisException("Cartesian joins could be prohibitively expensive and are " +
- "disabled by default. To explicitly enable them, please set " +
- s"${SQLConf.CROSS_JOINS_ENABLED.key} = true")
- }
- super.doPrepare()
- }
-
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index d46a80423f..fb6bfa7b27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -38,7 +38,7 @@ trait HashJoin {
override def output: Seq[Attribute] = {
joinType match {
- case Inner =>
+ case _: InnerLike =>
left.output ++ right.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
@@ -225,7 +225,7 @@ trait HashJoin {
numOutputRows: SQLMetric): Iterator[InternalRow] = {
val joinedIter = joinType match {
- case Inner =>
+ case _: InnerLike =>
innerJoin(streamedIter, hashed)
case LeftOuter | RightOuter =>
outerJoin(streamedIter, hashed)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 5c9c1e6062..b46af2a99a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -45,7 +45,7 @@ case class SortMergeJoinExec(
override def output: Seq[Attribute] = {
joinType match {
- case Inner =>
+ case _: InnerLike =>
left.output ++ right.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
@@ -64,7 +64,8 @@ case class SortMergeJoinExec(
}
override def outputPartitioning: Partitioning = joinType match {
- case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+ case _: InnerLike =>
+ PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
// For left and right outer joins, the output is partitioned by the streamed input's join keys.
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
@@ -111,7 +112,7 @@ case class SortMergeJoinExec(
val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
joinType match {
- case Inner =>
+ case _: InnerLike =>
new RowIterator {
private[this] var currentLeftRow: InternalRow = _
private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
@@ -318,7 +319,7 @@ case class SortMergeJoinExec(
}
override def supportCodegen: Boolean = {
- joinType == Inner
+ joinType.isInstanceOf[InnerLike]
}
override def inputRDDs(): Seq[RDD[InternalRow]] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index a54342f82e..1d6ca5a965 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -362,7 +362,8 @@ object SQLConf {
.createWithDefault(true)
val CROSS_JOINS_ENABLED = SQLConfigBuilder("spark.sql.crossJoin.enabled")
- .doc("When false, we will throw an error if a query contains a cross join")
+ .doc("When false, we will throw an error if a query contains a cartesian product without " +
+ "explicit CROSS JOIN syntax.")
.booleanConf
.createWithDefault(false)
@@ -683,8 +684,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED)
- def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
-
// Do not use a value larger than 4000 as the default value of this property.
// See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD)
@@ -709,6 +708,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
+
+ override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
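The flag is the escape hatch for queries that genuinely need a cartesian product without CROSS JOIN syntax. A small sketch of toggling it at runtime (equivalent to the "set spark.sql.crossJoin.enabled = true" statement used by outer-join.sql below; t1 and t2 are that test's views):

    // Enable, run the condition-less LEFT JOIN, then restore the default.
    spark.conf.set("spark.sql.crossJoin.enabled", "true")
    spark.sql("SELECT * FROM t1 LEFT JOIN t2 ON false").show()
    spark.conf.set("spark.sql.crossJoin.enabled", "false")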
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cross-join.sql b/sql/core/src/test/resources/sql-tests/inputs/cross-join.sql
new file mode 100644
index 0000000000..aa73124374
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/cross-join.sql
@@ -0,0 +1,35 @@
+-- Cross join detection and error checking is done in JoinSuite since explain output is
+-- used in the error message and the ids are not stable. Only positive cases are checked here.
+
+create temporary view nt1 as select * from values
+ ("one", 1),
+ ("two", 2),
+ ("three", 3)
+ as nt1(k, v1);
+
+create temporary view nt2 as select * from values
+ ("one", 1),
+ ("two", 22),
+ ("one", 5)
+ as nt2(k, v2);
+
+-- Cross joins with and without predicates
+SELECT * FROM nt1 cross join nt2;
+SELECT * FROM nt1 cross join nt2 where nt1.k = nt2.k;
+SELECT * FROM nt1 cross join nt2 on (nt1.k = nt2.k);
+SELECT * FROM nt1 cross join nt2 where nt1.v1 = 1 and nt2.v2 = 22;
+
+SELECT a.key, b.key FROM
+(SELECT k key FROM nt1 WHERE v1 < 2) a
+CROSS JOIN
+(SELECT k key FROM nt2 WHERE v2 = 22) b;
+
+-- Join reordering
+create temporary view A(a, va) as select * from nt1;
+create temporary view B(b, vb) as select * from nt1;
+create temporary view C(c, vc) as select * from nt1;
+create temporary view D(d, vd) as select * from nt1;
+
+-- Allowed since cross join with C is explicit
+select * from ((A join B on (a = b)) cross join C) join D on (a = d);
+
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte.sql b/sql/core/src/test/resources/sql-tests/inputs/cte.sql
index 10d34deff4..3914db2691 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cte.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cte.sql
@@ -11,4 +11,4 @@ WITH t AS (SELECT 1 FROM t) SELECT * FROM t;
WITH s1 AS (SELECT 1 FROM s2), s2 AS (SELECT 1 FROM s1) SELECT * FROM s1, s2;
-- WITH clause should reference the previous CTE
-WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1, t2;
+WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1 cross join t2;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql
index f50f1ebad9..cdc6c81e10 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql
@@ -24,6 +24,9 @@ CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (97) as t1(int_col1)
CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (0) as t2(int_col1);
+-- Set the cross join enabled flag for the LEFT JOIN test since there's no join condition.
+-- Ultimately the join should be optimized away.
+set spark.sql.crossJoin.enabled = true;
SELECT *
FROM (
SELECT
@@ -31,6 +34,6 @@ SELECT
FROM t1
LEFT JOIN t2 ON false
) t where (t.int_col) is not null;
-
+set spark.sql.crossJoin.enabled = false;
diff --git a/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out b/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out
new file mode 100644
index 0000000000..562e174fc0
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out
@@ -0,0 +1,129 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 12
+
+
+-- !query 0
+create temporary view nt1 as select * from values
+ ("one", 1),
+ ("two", 2),
+ ("three", 3)
+ as nt1(k, v1)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+create temporary view nt2 as select * from values
+ ("one", 1),
+ ("two", 22),
+ ("one", 5)
+ as nt2(k, v2)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+SELECT * FROM nt1 cross join nt2
+-- !query 2 schema
+struct<k:string,v1:int,k:string,v2:int>
+-- !query 2 output
+one 1 one 1
+one 1 one 5
+one 1 two 22
+three 3 one 1
+three 3 one 5
+three 3 two 22
+two 2 one 1
+two 2 one 5
+two 2 two 22
+
+
+-- !query 3
+SELECT * FROM nt1 cross join nt2 where nt1.k = nt2.k
+-- !query 3 schema
+struct<k:string,v1:int,k:string,v2:int>
+-- !query 3 output
+one 1 one 1
+one 1 one 5
+two 2 two 22
+
+
+-- !query 4
+SELECT * FROM nt1 cross join nt2 on (nt1.k = nt2.k)
+-- !query 4 schema
+struct<k:string,v1:int,k:string,v2:int>
+-- !query 4 output
+one 1 one 1
+one 1 one 5
+two 2 two 22
+
+
+-- !query 5
+SELECT * FROM nt1 cross join nt2 where nt1.v1 = 1 and nt2.v2 = 22
+-- !query 5 schema
+struct<k:string,v1:int,k:string,v2:int>
+-- !query 5 output
+one 1 two 22
+
+
+-- !query 6
+SELECT a.key, b.key FROM
+(SELECT k key FROM nt1 WHERE v1 < 2) a
+CROSS JOIN
+(SELECT k key FROM nt2 WHERE v2 = 22) b
+-- !query 6 schema
+struct<key:string,key:string>
+-- !query 6 output
+one two
+
+
+-- !query 7
+create temporary view A(a, va) as select * from nt1
+-- !query 7 schema
+struct<>
+-- !query 7 output
+
+
+
+-- !query 8
+create temporary view B(b, vb) as select * from nt1
+-- !query 8 schema
+struct<>
+-- !query 8 output
+
+
+
+-- !query 9
+create temporary view C(c, vc) as select * from nt1
+-- !query 9 schema
+struct<>
+-- !query 9 output
+
+
+
+-- !query 10
+create temporary view D(d, vd) as select * from nt1
+-- !query 10 schema
+struct<>
+-- !query 10 output
+
+
+
+-- !query 11
+select * from ((A join B on (a = b)) cross join C) join D on (a = d)
+-- !query 11 schema
+struct<a:string,va:int,b:string,vb:int,c:string,vc:int,d:string,vd:int>
+-- !query 11 output
+one 1 one 1 one 1 one 1
+one 1 one 1 three 3 one 1
+one 1 one 1 two 2 one 1
+three 3 three 3 one 1 three 3
+three 3 three 3 three 3 three 3
+three 3 three 3 two 2 three 3
+two 2 two 2 one 1 two 2
+two 2 two 2 three 3 two 2
+two 2 two 2 two 2 two 2
diff --git a/sql/core/src/test/resources/sql-tests/results/cte.sql.out b/sql/core/src/test/resources/sql-tests/results/cte.sql.out
index ddee5bf2d4..9fbad8f380 100644
--- a/sql/core/src/test/resources/sql-tests/results/cte.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cte.sql.out
@@ -47,7 +47,7 @@ Table or view not found: s2; line 1 pos 26
-- !query 5
-WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1, t2
+WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1 cross join t2
-- !query 5 schema
struct<id:int,2:int>
-- !query 5 output
diff --git a/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out
index b39fdb0e58..cc50b9444b 100644
--- a/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 6
+-- Number of queries: 8
-- !query 0
@@ -59,6 +59,14 @@ struct<>
-- !query 5
+set spark.sql.crossJoin.enabled = true
+-- !query 5 schema
+struct<key:string,value:string>
+-- !query 5 output
+spark.sql.crossJoin.enabled
+
+
+-- !query 6
SELECT *
FROM (
SELECT
@@ -66,7 +74,15 @@ SELECT
FROM t1
LEFT JOIN t2 ON false
) t where (t.int_col) is not null
--- !query 5 schema
+-- !query 6 schema
struct<int_col:int>
--- !query 5 output
+-- !query 6 output
97
+
+
+-- !query 7
+set spark.sql.crossJoin.enabled = false
+-- !query 7 schema
+struct<key:string,value:string>
+-- !query 7 output
+spark.sql.crossJoin.enabled
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 4abf5e42b9..541ffb58e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -104,6 +104,21 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
.collect().toSeq)
}
+ test("join - cross join") {
+ val df1 = Seq((1, "1"), (3, "3")).toDF("int", "str")
+ val df2 = Seq((2, "2"), (4, "4")).toDF("int", "str")
+
+ checkAnswer(
+ df1.crossJoin(df2),
+ Row(1, "1", 2, "2") :: Row(1, "1", 4, "4") ::
+ Row(3, "3", 2, "2") :: Row(3, "3", 4, "4") :: Nil)
+
+ checkAnswer(
+ df2.crossJoin(df1),
+ Row(2, "2", 1, "1") :: Row(2, "2", 3, "3") ::
+ Row(4, "4", 1, "1") :: Row(4, "4", 3, "3") :: Nil)
+ }
+
test("join - using aliases after self join") {
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
checkAnswer(
@@ -145,7 +160,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size === 1)
// no join key -- should not be a broadcast join
- val plan2 = df1.join(broadcast(df2)).queryExecution.sparkPlan
+ val plan2 = df1.crossJoin(broadcast(df2)).queryExecution.sparkPlan
assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size === 0)
// planner should not crash without a join
@@ -155,7 +170,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
withTempPath { path =>
df1.write.parquet(path.getCanonicalPath)
val pf1 = spark.read.parquet(path.getCanonicalPath)
- assert(df1.join(broadcast(pf1)).count() === 4)
+ assert(df1.crossJoin(broadcast(pf1)).count() === 4)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f89951760f..c2d256bdd3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -626,9 +626,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("drop(name: String) search and drop all top level columns that matchs the name") {
val df1 = Seq((1, 2)).toDF("a", "b")
val df2 = Seq((3, 4)).toDF("a", "b")
- checkAnswer(df1.join(df2), Row(1, 2, 3, 4))
+ checkAnswer(df1.crossJoin(df2), Row(1, 2, 3, 4))
// Finds and drops all columns that match the name (case insensitive).
- checkAnswer(df1.join(df2).drop("A"), Row(2, 4))
+ checkAnswer(df1.crossJoin(df2).drop("A"), Row(2, 4))
}
test("withColumnRenamed") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 8ce6ea66b6..3243f352a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -466,7 +466,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("self join") {
val ds = Seq("1", "2").toDS().as("a")
- val joined = ds.joinWith(ds, lit(true))
+ val joined = ds.joinWith(ds, lit(true), "cross")
checkDataset(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
}
@@ -486,7 +486,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("Kryo encoder self join") {
implicit val kryoEncoder = Encoders.kryo[KryoData]
val ds = Seq(KryoData(1), KryoData(2)).toDS()
- assert(ds.joinWith(ds, lit(true)).collect().toSet ==
+ assert(ds.joinWith(ds, lit(true), "cross").collect().toSet ==
Set(
(KryoData(1), KryoData(1)),
(KryoData(1), KryoData(2)),
@@ -514,7 +514,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("Java encoder self join") {
implicit val kryoEncoder = Encoders.javaSerialization[JavaData]
val ds = Seq(JavaData(1), JavaData(2)).toDS()
- assert(ds.joinWith(ds, lit(true)).collect().toSet ==
+ assert(ds.joinWith(ds, lit(true), "cross").collect().toSet ==
Set(
(JavaData(1), JavaData(1)),
(JavaData(1), JavaData(2)),
@@ -532,7 +532,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS()
checkDataset(
- ds1.joinWith(ds2, lit(true)),
+ ds1.joinWith(ds2, lit(true), "cross"),
((nullInt, "1"), (nullInt, "1")),
((nullInt, "1"), (new java.lang.Integer(22), "2")),
((new java.lang.Integer(22), "2"), (nullInt, "1")),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 44889d92ee..913b2ae976 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -225,8 +225,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Row(2, 2, 1, null) ::
Row(2, 2, 2, 2) :: Nil)
}
- assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive and are " +
- "disabled by default"))
+ assert(e.getMessage.contains("Detected cartesian product for INNER join " +
+ "between logical plans"))
}
}
@@ -482,7 +482,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
// we set the threshold is greater than statistic of the cached table testData
withSQLConf(
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) {
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString(),
+ SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
assert(statisticSizeInByte(spark.table("testData2")) >
spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
@@ -573,4 +574,34 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Row(3, 1) ::
Row(3, 2) :: Nil)
}
+
+ test("cross join detection") {
+ testData.createOrReplaceTempView("A")
+ testData.createOrReplaceTempView("B")
+ testData2.createOrReplaceTempView("C")
+ testData3.createOrReplaceTempView("D")
+ upperCaseData.where('N >= 3).createOrReplaceTempView("`right`")
+ val cartesianQueries = Seq(
+ /** The following should error out since there is no explicit cross join */
+ "SELECT * FROM testData inner join testData2",
+ "SELECT * FROM testData left outer join testData2",
+ "SELECT * FROM testData right outer join testData2",
+ "SELECT * FROM testData full outer join testData2",
+ "SELECT * FROM testData, testData2",
+ "SELECT * FROM testData, testData2 where testData.key = 1 and testData2.a = 22",
+ /** The following should fail because after reordering there are cartesian products */
+ "select * from (A join B on (A.key = B.key)) join D on (A.key=D.a) join C",
+ "select * from ((A join B on (A.key = B.key)) join C) join D on (A.key = D.a)",
+ /** Cartesian product involving C, which is not involved in a CROSS join */
+ "select * from ((A join B on (A.key = B.key)) cross join D) join C on (A.key = D.a)");
+
+ def checkCartesianDetection(query: String): Unit = {
+ val e = intercept[Exception] {
+ checkAnswer(sql(query), Nil);
+ }
+ assert(e.getMessage.contains("Detected cartesian product"))
+ }
+
+ cartesianQueries.foreach(checkCartesianDetection)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index d3cfa953a3..afd47897ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -361,7 +361,8 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
|with
| v0 as (select 0 as key, 1 as value),
| v1 as (select key, count(value) over (partition by key) cnt_val from v0),
- | v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key)
+ | v2 as (select v1.key, v1_lag.cnt_val from v1 cross join v1 v1_lag
+ | where v1.key = v1_lag.key)
| select key, cnt_val from v2 order by key limit 1
""".stripMargin), Row(0, 1))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 35dab63672..4408ece112 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -109,8 +109,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
- val shuffledHashJoin =
- joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan)
+ val shuffledHashJoin = joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner,
+ side, None, leftPlan, rightPlan)
val filteredJoin =
boundCondition.map(FilterExec(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
EnsureRequirements(spark.sessionState.conf).apply(filteredJoin)
@@ -122,8 +122,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
boundCondition: Option[Expression],
leftPlan: SparkPlan,
rightPlan: SparkPlan) = {
- val sortMergeJoin =
- joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan)
+ val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition,
+ leftPlan, rightPlan)
EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin)
}