Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 7
-rw-r--r--  sql/core/src/test/resources/sql-tests/inputs/cross-join.sql | 35
-rw-r--r--  sql/core/src/test/resources/sql-tests/inputs/cte.sql | 2
-rw-r--r--  sql/core/src/test/resources/sql-tests/inputs/outer-join.sql | 5
-rw-r--r--  sql/core/src/test/resources/sql-tests/results/cross-join.sql.out | 129
-rw-r--r--  sql/core/src/test/resources/sql-tests/results/cte.sql.out | 2
-rw-r--r--  sql/core/src/test/resources/sql-tests/results/outer-join.sql.out | 22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala | 19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 37
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 8
20 files changed, 288 insertions, 62 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index e7dcf0f51f..3b3cb82078 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -589,9 +589,9 @@ class Dataset[T] private[sql](
def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF())
/**
- * Cartesian join with another [[DataFrame]].
+ * Join with another [[DataFrame]].
*
- * Note that cartesian joins are very expensive without an extra filter that can be pushed down.
+ * Behaves as an INNER JOIN and requires a subsequent join predicate.
*
* @param right Right side of the join operation.
*
@@ -764,6 +764,20 @@ class Dataset[T] private[sql](
}
/**
+ * Explicit cartesian join with another [[DataFrame]].
+ *
+ * Note that cartesian joins are very expensive without an extra filter that can be pushed down.
+ *
+ * @param right Right side of the join operation.
+ *
+ * @group untypedrel
+ * @since 2.0.0
+ */
+ def crossJoin(right: Dataset[_]): DataFrame = withPlan {
+ Join(logicalPlan, right.logicalPlan, joinType = Cross, None)
+ }
+
+ /**
* :: Experimental ::
* Joins this Dataset returning a [[Tuple2]] for each pair where `condition` evaluates to
* true.
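
For orientation, a minimal usage sketch of the Dataset API change above; spark is an assumed SparkSession and the DataFrames are illustrative, not part of the patch:

    import spark.implicits._

    val df1 = Seq((1, "a"), (2, "b")).toDF("id", "v")
    val df2 = Seq((10, "x"), (20, "y")).toDF("id2", "w")

    // Explicit cartesian product via the new method; no join predicate is required.
    val crossed = df1.crossJoin(df2)                           // 4 rows

    // join(...) with no condition now expects a subsequent predicate; an equality
    // in where(...) keeps this an ordinary inner join once it is pushed down.
    val inner = df1.join(df2).where(df1("id") === df2("id2"))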
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index b4899ad688..c389593b4f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -140,13 +140,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
private def canBuildRight(joinType: JoinType): Boolean = joinType match {
- case Inner | LeftOuter | LeftSemi | LeftAnti => true
+ case _: InnerLike | LeftOuter | LeftSemi | LeftAnti => true
case j: ExistenceJoin => true
case _ => false
}
private def canBuildLeft(joinType: JoinType): Boolean = joinType match {
- case Inner | RightOuter => true
+ case _: InnerLike | RightOuter => true
case _ => false
}
@@ -200,7 +200,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
planLater(left), planLater(right), BuildLeft, joinType, condition) :: Nil
// Pick CartesianProduct for InnerJoin
- case logical.Join(left, right, Inner, condition) =>
+ case logical.Join(left, right, _: InnerLike, condition) =>
joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
case logical.Join(left, right, joinType, condition) =>
@@ -212,8 +212,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
// This join could be very slow or OOM
joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), buildSide, joinType, condition,
- withinBroadcastThreshold = false) :: Nil
+ planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
// --- Cases where this strategy does not apply ---------------------------------------------
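
A sketch of why the pattern changes above work, assuming the catalyst join-type hierarchy introduced alongside this patch (Inner and the new Cross both extend InnerLike); this is illustrative, not code from the patch:

    import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, JoinType, LeftOuter}

    // The "_: InnerLike" pattern covers both Inner and Cross, so the strategy
    // cases above treat an explicit cross join like any other inner-style join.
    def isInnerLike(jt: JoinType): Boolean = jt match {
      case _: InnerLike => true
      case _            => false
    }

    assert(isInnerLike(Inner) && isInnerLike(Cross) && !isInnerLike(LeftOuter))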
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 0f24baacd1..0bc261d593 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -79,7 +79,7 @@ case class BroadcastHashJoinExec(
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
joinType match {
- case Inner => codegenInner(ctx, input)
+ case _: InnerLike => codegenInner(ctx, input)
case LeftOuter | RightOuter => codegenOuter(ctx, input)
case LeftSemi => codegenSemi(ctx, input)
case LeftAnti => codegenAnti(ctx, input)
@@ -134,7 +134,7 @@ case class BroadcastHashJoinExec(
ctx.INPUT_ROW = matched
buildPlan.output.zipWithIndex.map { case (a, i) =>
val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
- if (joinType == Inner) {
+ if (joinType.isInstanceOf[InnerLike]) {
ev
} else {
// the variables are needed even there is no matched rows
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 6a9965f1a2..43cdce7de8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -34,8 +34,7 @@ case class BroadcastNestedLoopJoinExec(
right: SparkPlan,
buildSide: BuildSide,
joinType: JoinType,
- condition: Option[Expression],
- withinBroadcastThreshold: Boolean = true) extends BinaryExecNode {
+ condition: Option[Expression]) extends BinaryExecNode {
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -65,7 +64,7 @@ case class BroadcastNestedLoopJoinExec(
override def output: Seq[Attribute] = {
joinType match {
- case Inner =>
+ case _: InnerLike =>
left.output ++ right.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
@@ -340,20 +339,11 @@ case class BroadcastNestedLoopJoinExec(
)
}
- protected override def doPrepare(): Unit = {
- if (!withinBroadcastThreshold && !sqlContext.conf.crossJoinEnabled) {
- throw new AnalysisException("Both sides of this join are outside the broadcasting " +
- "threshold and computing it could be prohibitively expensive. To explicitly enable it, " +
- s"please set ${SQLConf.CROSS_JOINS_ENABLED.key} = true")
- }
- super.doPrepare()
- }
-
protected override def doExecute(): RDD[InternalRow] = {
val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
val resultRdd = (joinType, buildSide) match {
- case (Inner, _) =>
+ case (_: InnerLike, _) =>
innerJoin(broadcastedRelation)
case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
outerJoin(broadcastedRelation)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 57866df90d..15dc9b4066 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -91,15 +91,6 @@ case class CartesianProductExec(
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
- protected override def doPrepare(): Unit = {
- if (!sqlContext.conf.crossJoinEnabled) {
- throw new AnalysisException("Cartesian joins could be prohibitively expensive and are " +
- "disabled by default. To explicitly enable them, please set " +
- s"${SQLConf.CROSS_JOINS_ENABLED.key} = true")
- }
- super.doPrepare()
- }
-
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
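
With the doPrepare() checks removed here and in BroadcastNestedLoopJoinExec, an unintended cartesian product is now rejected while the logical plan is checked rather than at execution time (see the "Detected cartesian product" assertions in JoinSuite below). A small sketch of the resulting behaviour, with spark as an assumed SparkSession and the flag left at its default of false:

    import spark.implicits._

    val a = Seq(1, 2).toDF("x")
    val b = Seq(3, 4).toDF("y")

    // a.join(b).collect()     // expected to fail: "Detected cartesian product ..."
    a.crossJoin(b).collect()   // allowed: the product is requested explicitly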
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index d46a80423f..fb6bfa7b27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -38,7 +38,7 @@ trait HashJoin {
override def output: Seq[Attribute] = {
joinType match {
- case Inner =>
+ case _: InnerLike =>
left.output ++ right.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
@@ -225,7 +225,7 @@ trait HashJoin {
numOutputRows: SQLMetric): Iterator[InternalRow] = {
val joinedIter = joinType match {
- case Inner =>
+ case _: InnerLike =>
innerJoin(streamedIter, hashed)
case LeftOuter | RightOuter =>
outerJoin(streamedIter, hashed)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 5c9c1e6062..b46af2a99a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -45,7 +45,7 @@ case class SortMergeJoinExec(
override def output: Seq[Attribute] = {
joinType match {
- case Inner =>
+ case _: InnerLike =>
left.output ++ right.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
@@ -64,7 +64,8 @@ case class SortMergeJoinExec(
}
override def outputPartitioning: Partitioning = joinType match {
- case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+ case _: InnerLike =>
+ PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
// For left and right outer joins, the output is partitioned by the streamed input's join keys.
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
@@ -111,7 +112,7 @@ case class SortMergeJoinExec(
val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
joinType match {
- case Inner =>
+ case _: InnerLike =>
new RowIterator {
private[this] var currentLeftRow: InternalRow = _
private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
@@ -318,7 +319,7 @@ case class SortMergeJoinExec(
}
override def supportCodegen: Boolean = {
- joinType == Inner
+ joinType.isInstanceOf[InnerLike]
}
override def inputRDDs(): Seq[RDD[InternalRow]] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index a54342f82e..1d6ca5a965 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -362,7 +362,8 @@ object SQLConf {
.createWithDefault(true)
val CROSS_JOINS_ENABLED = SQLConfigBuilder("spark.sql.crossJoin.enabled")
- .doc("When false, we will throw an error if a query contains a cross join")
+ .doc("When false, we will throw an error if a query contains a cartesian product without " +
+ "explicit CROSS JOIN syntax.")
.booleanConf
.createWithDefault(false)
@@ -683,8 +684,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED)
- def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
-
// Do not use a value larger than 4000 as the default value of this property.
// See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD)
@@ -709,6 +708,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
+
+ override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
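
For reference, the flag can also be flipped at runtime from Scala (spark is an assumed SparkSession), mirroring the SET statements used in the SQL test inputs below:

    // Allow implicit cartesian products for the current session, then restore the default.
    spark.conf.set("spark.sql.crossJoin.enabled", "true")
    spark.conf.set("spark.sql.crossJoin.enabled", "false")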
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cross-join.sql b/sql/core/src/test/resources/sql-tests/inputs/cross-join.sql
new file mode 100644
index 0000000000..aa73124374
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/cross-join.sql
@@ -0,0 +1,35 @@
+-- Cross join detection and error checking is done in JoinSuite since explain output is
+-- used in the error message and the ids are not stable. Only positive cases are checked here.
+
+create temporary view nt1 as select * from values
+ ("one", 1),
+ ("two", 2),
+ ("three", 3)
+ as nt1(k, v1);
+
+create temporary view nt2 as select * from values
+ ("one", 1),
+ ("two", 22),
+ ("one", 5)
+ as nt2(k, v2);
+
+-- Cross joins with and without predicates
+SELECT * FROM nt1 cross join nt2;
+SELECT * FROM nt1 cross join nt2 where nt1.k = nt2.k;
+SELECT * FROM nt1 cross join nt2 on (nt1.k = nt2.k);
+SELECT * FROM nt1 cross join nt2 where nt1.v1 = 1 and nt2.v2 = 22;
+
+SELECT a.key, b.key FROM
+(SELECT k key FROM nt1 WHERE v1 < 2) a
+CROSS JOIN
+(SELECT k key FROM nt2 WHERE v2 = 22) b;
+
+-- Join reordering
+create temporary view A(a, va) as select * from nt1;
+create temporary view B(b, vb) as select * from nt1;
+create temporary view C(c, vc) as select * from nt1;
+create temporary view D(d, vd) as select * from nt1;
+
+-- Allowed since cross join with C is explicit
+select * from ((A join B on (a = b)) cross join C) join D on (a = d);
+
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte.sql b/sql/core/src/test/resources/sql-tests/inputs/cte.sql
index 10d34deff4..3914db2691 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cte.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cte.sql
@@ -11,4 +11,4 @@ WITH t AS (SELECT 1 FROM t) SELECT * FROM t;
WITH s1 AS (SELECT 1 FROM s2), s2 AS (SELECT 1 FROM s1) SELECT * FROM s1, s2;
-- WITH clause should reference the previous CTE
-WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1, t2;
+WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1 cross join t2;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql
index f50f1ebad9..cdc6c81e10 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql
@@ -24,6 +24,9 @@ CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (97) as t1(int_col1)
CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (0) as t2(int_col1);
+-- Set the cross join enabled flag for the LEFT JOIN test since there's no join condition.
+-- Ultimately the join should be optimized away.
+set spark.sql.crossJoin.enabled = true;
SELECT *
FROM (
SELECT
@@ -31,6 +34,6 @@ SELECT
FROM t1
LEFT JOIN t2 ON false
) t where (t.int_col) is not null;
-
+set spark.sql.crossJoin.enabled = false;
diff --git a/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out b/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out
new file mode 100644
index 0000000000..562e174fc0
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out
@@ -0,0 +1,129 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 12
+
+
+-- !query 0
+create temporary view nt1 as select * from values
+ ("one", 1),
+ ("two", 2),
+ ("three", 3)
+ as nt1(k, v1)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+create temporary view nt2 as select * from values
+ ("one", 1),
+ ("two", 22),
+ ("one", 5)
+ as nt2(k, v2)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+SELECT * FROM nt1 cross join nt2
+-- !query 2 schema
+struct<k:string,v1:int,k:string,v2:int>
+-- !query 2 output
+one 1 one 1
+one 1 one 5
+one 1 two 22
+three 3 one 1
+three 3 one 5
+three 3 two 22
+two 2 one 1
+two 2 one 5
+two 2 two 22
+
+
+-- !query 3
+SELECT * FROM nt1 cross join nt2 where nt1.k = nt2.k
+-- !query 3 schema
+struct<k:string,v1:int,k:string,v2:int>
+-- !query 3 output
+one 1 one 1
+one 1 one 5
+two 2 two 22
+
+
+-- !query 4
+SELECT * FROM nt1 cross join nt2 on (nt1.k = nt2.k)
+-- !query 4 schema
+struct<k:string,v1:int,k:string,v2:int>
+-- !query 4 output
+one 1 one 1
+one 1 one 5
+two 2 two 22
+
+
+-- !query 5
+SELECT * FROM nt1 cross join nt2 where nt1.v1 = 1 and nt2.v2 = 22
+-- !query 5 schema
+struct<k:string,v1:int,k:string,v2:int>
+-- !query 5 output
+one 1 two 22
+
+
+-- !query 6
+SELECT a.key, b.key FROM
+(SELECT k key FROM nt1 WHERE v1 < 2) a
+CROSS JOIN
+(SELECT k key FROM nt2 WHERE v2 = 22) b
+-- !query 6 schema
+struct<key:string,key:string>
+-- !query 6 output
+one two
+
+
+-- !query 7
+create temporary view A(a, va) as select * from nt1
+-- !query 7 schema
+struct<>
+-- !query 7 output
+
+
+
+-- !query 8
+create temporary view B(b, vb) as select * from nt1
+-- !query 8 schema
+struct<>
+-- !query 8 output
+
+
+
+-- !query 9
+create temporary view C(c, vc) as select * from nt1
+-- !query 9 schema
+struct<>
+-- !query 9 output
+
+
+
+-- !query 10
+create temporary view D(d, vd) as select * from nt1
+-- !query 10 schema
+struct<>
+-- !query 10 output
+
+
+
+-- !query 11
+select * from ((A join B on (a = b)) cross join C) join D on (a = d)
+-- !query 11 schema
+struct<a:string,va:int,b:string,vb:int,c:string,vc:int,d:string,vd:int>
+-- !query 11 output
+one 1 one 1 one 1 one 1
+one 1 one 1 three 3 one 1
+one 1 one 1 two 2 one 1
+three 3 three 3 one 1 three 3
+three 3 three 3 three 3 three 3
+three 3 three 3 two 2 three 3
+two 2 two 2 one 1 two 2
+two 2 two 2 three 3 two 2
+two 2 two 2 two 2 two 2
diff --git a/sql/core/src/test/resources/sql-tests/results/cte.sql.out b/sql/core/src/test/resources/sql-tests/results/cte.sql.out
index ddee5bf2d4..9fbad8f380 100644
--- a/sql/core/src/test/resources/sql-tests/results/cte.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cte.sql.out
@@ -47,7 +47,7 @@ Table or view not found: s2; line 1 pos 26
-- !query 5
-WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1, t2
+WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1 cross join t2
-- !query 5 schema
struct<id:int,2:int>
-- !query 5 output
diff --git a/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out
index b39fdb0e58..cc50b9444b 100644
--- a/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 6
+-- Number of queries: 8
-- !query 0
@@ -59,6 +59,14 @@ struct<>
-- !query 5
+set spark.sql.crossJoin.enabled = true
+-- !query 5 schema
+struct<key:string,value:string>
+-- !query 5 output
+spark.sql.crossJoin.enabled
+
+
+-- !query 6
SELECT *
FROM (
SELECT
@@ -66,7 +74,15 @@ SELECT
FROM t1
LEFT JOIN t2 ON false
) t where (t.int_col) is not null
--- !query 5 schema
+-- !query 6 schema
struct<int_col:int>
--- !query 5 output
+-- !query 6 output
97
+
+
+-- !query 7
+set spark.sql.crossJoin.enabled = false
+-- !query 7 schema
+struct<key:string,value:string>
+-- !query 7 output
+spark.sql.crossJoin.enabled
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 4abf5e42b9..541ffb58e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -104,6 +104,21 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
.collect().toSeq)
}
+ test("join - cross join") {
+ val df1 = Seq((1, "1"), (3, "3")).toDF("int", "str")
+ val df2 = Seq((2, "2"), (4, "4")).toDF("int", "str")
+
+ checkAnswer(
+ df1.crossJoin(df2),
+ Row(1, "1", 2, "2") :: Row(1, "1", 4, "4") ::
+ Row(3, "3", 2, "2") :: Row(3, "3", 4, "4") :: Nil)
+
+ checkAnswer(
+ df2.crossJoin(df1),
+ Row(2, "2", 1, "1") :: Row(2, "2", 3, "3") ::
+ Row(4, "4", 1, "1") :: Row(4, "4", 3, "3") :: Nil)
+ }
+
test("join - using aliases after self join") {
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
checkAnswer(
@@ -145,7 +160,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size === 1)
// no join key -- should not be a broadcast join
- val plan2 = df1.join(broadcast(df2)).queryExecution.sparkPlan
+ val plan2 = df1.crossJoin(broadcast(df2)).queryExecution.sparkPlan
assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size === 0)
// planner should not crash without a join
@@ -155,7 +170,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
withTempPath { path =>
df1.write.parquet(path.getCanonicalPath)
val pf1 = spark.read.parquet(path.getCanonicalPath)
- assert(df1.join(broadcast(pf1)).count() === 4)
+ assert(df1.crossJoin(broadcast(pf1)).count() === 4)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f89951760f..c2d256bdd3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -626,9 +626,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("drop(name: String) search and drop all top level columns that matchs the name") {
val df1 = Seq((1, 2)).toDF("a", "b")
val df2 = Seq((3, 4)).toDF("a", "b")
- checkAnswer(df1.join(df2), Row(1, 2, 3, 4))
+ checkAnswer(df1.crossJoin(df2), Row(1, 2, 3, 4))
// Finds and drops all columns that match the name (case insensitive).
- checkAnswer(df1.join(df2).drop("A"), Row(2, 4))
+ checkAnswer(df1.crossJoin(df2).drop("A"), Row(2, 4))
}
test("withColumnRenamed") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 8ce6ea66b6..3243f352a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -466,7 +466,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("self join") {
val ds = Seq("1", "2").toDS().as("a")
- val joined = ds.joinWith(ds, lit(true))
+ val joined = ds.joinWith(ds, lit(true), "cross")
checkDataset(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
}
@@ -486,7 +486,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("Kryo encoder self join") {
implicit val kryoEncoder = Encoders.kryo[KryoData]
val ds = Seq(KryoData(1), KryoData(2)).toDS()
- assert(ds.joinWith(ds, lit(true)).collect().toSet ==
+ assert(ds.joinWith(ds, lit(true), "cross").collect().toSet ==
Set(
(KryoData(1), KryoData(1)),
(KryoData(1), KryoData(2)),
@@ -514,7 +514,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("Java encoder self join") {
implicit val kryoEncoder = Encoders.javaSerialization[JavaData]
val ds = Seq(JavaData(1), JavaData(2)).toDS()
- assert(ds.joinWith(ds, lit(true)).collect().toSet ==
+ assert(ds.joinWith(ds, lit(true), "cross").collect().toSet ==
Set(
(JavaData(1), JavaData(1)),
(JavaData(1), JavaData(2)),
@@ -532,7 +532,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS()
checkDataset(
- ds1.joinWith(ds2, lit(true)),
+ ds1.joinWith(ds2, lit(true), "cross"),
((nullInt, "1"), (nullInt, "1")),
((nullInt, "1"), (new java.lang.Integer(22), "2")),
((new java.lang.Integer(22), "2"), (nullInt, "1")),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 44889d92ee..913b2ae976 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -225,8 +225,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Row(2, 2, 1, null) ::
Row(2, 2, 2, 2) :: Nil)
}
- assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive and are " +
- "disabled by default"))
+ assert(e.getMessage.contains("Detected cartesian product for INNER join " +
+ "between logical plans"))
}
}
@@ -482,7 +482,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
// we set the threshold is greater than statistic of the cached table testData
withSQLConf(
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) {
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString(),
+ SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
assert(statisticSizeInByte(spark.table("testData2")) >
spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
@@ -573,4 +574,34 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Row(3, 1) ::
Row(3, 2) :: Nil)
}
+
+ test("cross join detection") {
+ testData.createOrReplaceTempView("A")
+ testData.createOrReplaceTempView("B")
+ testData2.createOrReplaceTempView("C")
+ testData3.createOrReplaceTempView("D")
+ upperCaseData.where('N >= 3).createOrReplaceTempView("`right`")
+ val cartesianQueries = Seq(
+ /** The following should error out since there is no explicit cross join */
+ "SELECT * FROM testData inner join testData2",
+ "SELECT * FROM testData left outer join testData2",
+ "SELECT * FROM testData right outer join testData2",
+ "SELECT * FROM testData full outer join testData2",
+ "SELECT * FROM testData, testData2",
+ "SELECT * FROM testData, testData2 where testData.key = 1 and testData2.a = 22",
+ /** The following should fail because after reordering there are cartesian products */
+ "select * from (A join B on (A.key = B.key)) join D on (A.key=D.a) join C",
+ "select * from ((A join B on (A.key = B.key)) join C) join D on (A.key = D.a)",
+ /** Cartesian product involving C, which is not involved in a CROSS join */
+ "select * from ((A join B on (A.key = B.key)) cross join D) join C on (A.key = D.a)");
+
+ def checkCartesianDetection(query: String): Unit = {
+ val e = intercept[Exception] {
+ checkAnswer(sql(query), Nil);
+ }
+ assert(e.getMessage.contains("Detected cartesian product"))
+ }
+
+ cartesianQueries.foreach(checkCartesianDetection)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index d3cfa953a3..afd47897ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -361,7 +361,8 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
|with
| v0 as (select 0 as key, 1 as value),
| v1 as (select key, count(value) over (partition by key) cnt_val from v0),
- | v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key)
+ | v2 as (select v1.key, v1_lag.cnt_val from v1 cross join v1 v1_lag
+ | where v1.key = v1_lag.key)
| select key, cnt_val from v2 order by key limit 1
""".stripMargin), Row(0, 1))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 35dab63672..4408ece112 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -109,8 +109,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
- val shuffledHashJoin =
- joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan)
+ val shuffledHashJoin = joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner,
+ side, None, leftPlan, rightPlan)
val filteredJoin =
boundCondition.map(FilterExec(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
EnsureRequirements(spark.sessionState.conf).apply(filteredJoin)
@@ -122,8 +122,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
boundCondition: Option[Expression],
leftPlan: SparkPlan,
rightPlan: SparkPlan) = {
- val sortMergeJoin =
- joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan)
+ val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition,
+ leftPlan, rightPlan)
EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin)
}