diff options
Diffstat (limited to 'sql/core')
20 files changed, 288 insertions, 62 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e7dcf0f51f..3b3cb82078 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -589,9 +589,9 @@ class Dataset[T] private[sql]( def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF()) /** - * Cartesian join with another [[DataFrame]]. + * Join with another [[DataFrame]]. * - * Note that cartesian joins are very expensive without an extra filter that can be pushed down. + * Behaves as an INNER JOIN and requires a subsequent join predicate. * * @param right Right side of the join operation. * @@ -764,6 +764,20 @@ class Dataset[T] private[sql]( } /** + * Explicit cartesian join with another [[DataFrame]]. + * + * Note that cartesian joins are very expensive without an extra filter that can be pushed down. + * + * @param right Right side of the join operation. + * + * @group untypedrel + * @since 2.0.0 + */ + def crossJoin(right: Dataset[_]): DataFrame = withPlan { + Join(logicalPlan, right.logicalPlan, joinType = Cross, None) + } + + /** * :: Experimental :: * Joins this Dataset returning a [[Tuple2]] for each pair where `condition` evaluates to * true. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b4899ad688..c389593b4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -140,13 +140,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } private def canBuildRight(joinType: JoinType): Boolean = joinType match { - case Inner | LeftOuter | LeftSemi | LeftAnti => true + case _: InnerLike | LeftOuter | LeftSemi | LeftAnti => true case j: ExistenceJoin => true case _ => false } private def canBuildLeft(joinType: JoinType): Boolean = joinType match { - case Inner | RightOuter => true + case _: InnerLike | RightOuter => true case _ => false } @@ -200,7 +200,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { planLater(left), planLater(right), BuildLeft, joinType, condition) :: Nil // Pick CartesianProduct for InnerJoin - case logical.Join(left, right, Inner, condition) => + case logical.Join(left, right, _: InnerLike, condition) => joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil case logical.Join(left, right, joinType, condition) => @@ -212,8 +212,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } // This join could be very slow or OOM joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), buildSide, joinType, condition, - withinBroadcastThreshold = false) :: Nil + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil // --- Cases where this strategy does not apply --------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 0f24baacd1..0bc261d593 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -79,7 +79,7 @@ case class BroadcastHashJoinExec( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { joinType match { - case Inner => codegenInner(ctx, input) + case _: InnerLike => codegenInner(ctx, input) case LeftOuter | RightOuter => codegenOuter(ctx, input) case LeftSemi => codegenSemi(ctx, input) case LeftAnti => codegenAnti(ctx, input) @@ -134,7 +134,7 @@ case class BroadcastHashJoinExec( ctx.INPUT_ROW = matched buildPlan.output.zipWithIndex.map { case (a, i) => val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx) - if (joinType == Inner) { + if (joinType.isInstanceOf[InnerLike]) { ev } else { // the variables are needed even there is no matched rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 6a9965f1a2..43cdce7de8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -34,8 +34,7 @@ case class BroadcastNestedLoopJoinExec( right: SparkPlan, buildSide: BuildSide, joinType: JoinType, - condition: Option[Expression], - withinBroadcastThreshold: Boolean = true) extends BinaryExecNode { + condition: Option[Expression]) extends BinaryExecNode { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -65,7 +64,7 @@ case class BroadcastNestedLoopJoinExec( override def output: Seq[Attribute] = { joinType match { - case Inner => + case _: InnerLike => left.output ++ right.output case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -340,20 +339,11 @@ case class BroadcastNestedLoopJoinExec( ) } - protected override def doPrepare(): Unit = { - if (!withinBroadcastThreshold && !sqlContext.conf.crossJoinEnabled) { - throw new AnalysisException("Both sides of this join are outside the broadcasting " + - "threshold and computing it could be prohibitively expensive. To explicitly enable it, " + - s"please set ${SQLConf.CROSS_JOINS_ENABLED.key} = true") - } - super.doPrepare() - } - protected override def doExecute(): RDD[InternalRow] = { val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() val resultRdd = (joinType, buildSide) match { - case (Inner, _) => + case (_: InnerLike, _) => innerJoin(broadcastedRelation) case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => outerJoin(broadcastedRelation) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 57866df90d..15dc9b4066 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -91,15 +91,6 @@ case class CartesianProductExec( override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - protected override def doPrepare(): Unit = { - if (!sqlContext.conf.crossJoinEnabled) { - throw new AnalysisException("Cartesian joins could be prohibitively expensive and are " + - "disabled by default. To explicitly enable them, please set " + - s"${SQLConf.CROSS_JOINS_ENABLED.key} = true") - } - super.doPrepare() - } - protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index d46a80423f..fb6bfa7b27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -38,7 +38,7 @@ trait HashJoin { override def output: Seq[Attribute] = { joinType match { - case Inner => + case _: InnerLike => left.output ++ right.output case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -225,7 +225,7 @@ trait HashJoin { numOutputRows: SQLMetric): Iterator[InternalRow] = { val joinedIter = joinType match { - case Inner => + case _: InnerLike => innerJoin(streamedIter, hashed) case LeftOuter | RightOuter => outerJoin(streamedIter, hashed) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 5c9c1e6062..b46af2a99a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -45,7 +45,7 @@ case class SortMergeJoinExec( override def output: Seq[Attribute] = { joinType match { - case Inner => + case _: InnerLike => left.output ++ right.output case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -64,7 +64,8 @@ case class SortMergeJoinExec( } override def outputPartitioning: Partitioning = joinType match { - case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case _: InnerLike => + PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) // For left and right outer joins, the output is partitioned by the streamed input's join keys. case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning @@ -111,7 +112,7 @@ case class SortMergeJoinExec( val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output) joinType match { - case Inner => + case _: InnerLike => new RowIterator { private[this] var currentLeftRow: InternalRow = _ private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ @@ -318,7 +319,7 @@ case class SortMergeJoinExec( } override def supportCodegen: Boolean = { - joinType == Inner + joinType.isInstanceOf[InnerLike] } override def inputRDDs(): Seq[RDD[InternalRow]] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a54342f82e..1d6ca5a965 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -362,7 +362,8 @@ object SQLConf { .createWithDefault(true) val CROSS_JOINS_ENABLED = SQLConfigBuilder("spark.sql.crossJoin.enabled") - .doc("When false, we will throw an error if a query contains a cross join") + .doc("When false, we will throw an error if a query contains a cartesian product without " + + "explicit CROSS JOIN syntax.") .booleanConf .createWithDefault(false) @@ -683,8 +684,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) - def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) - // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) @@ -709,6 +708,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) + + override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/test/resources/sql-tests/inputs/cross-join.sql b/sql/core/src/test/resources/sql-tests/inputs/cross-join.sql new file mode 100644 index 0000000000..aa73124374 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/cross-join.sql @@ -0,0 +1,35 @@ +-- Cross join detection and error checking is done in JoinSuite since explain output is +-- used in the error message and the ids are not stable. Only positive cases are checked here. + +create temporary view nt1 as select * from values + ("one", 1), + ("two", 2), + ("three", 3) + as nt1(k, v1); + +create temporary view nt2 as select * from values + ("one", 1), + ("two", 22), + ("one", 5) + as nt2(k, v2); + +-- Cross joins with and without predicates +SELECT * FROM nt1 cross join nt2; +SELECT * FROM nt1 cross join nt2 where nt1.k = nt2.k; +SELECT * FROM nt1 cross join nt2 on (nt1.k = nt2.k); +SELECT * FROM nt1 cross join nt2 where nt1.v1 = 1 and nt2.v2 = 22; + +SELECT a.key, b.key FROM +(SELECT k key FROM nt1 WHERE v1 < 2) a +CROSS JOIN +(SELECT k key FROM nt2 WHERE v2 = 22) b; + +-- Join reordering +create temporary view A(a, va) as select * from nt1; +create temporary view B(b, vb) as select * from nt1; +create temporary view C(c, vc) as select * from nt1; +create temporary view D(d, vd) as select * from nt1; + +-- Allowed since cross join with C is explicit +select * from ((A join B on (a = b)) cross join C) join D on (a = d); + diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte.sql b/sql/core/src/test/resources/sql-tests/inputs/cte.sql index 10d34deff4..3914db2691 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cte.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cte.sql @@ -11,4 +11,4 @@ WITH t AS (SELECT 1 FROM t) SELECT * FROM t; WITH s1 AS (SELECT 1 FROM s2), s2 AS (SELECT 1 FROM s1) SELECT * FROM s1, s2; -- WITH clause should reference the previous CTE -WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1, t2; +WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1 cross join t2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql index f50f1ebad9..cdc6c81e10 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql @@ -24,6 +24,9 @@ CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (97) as t1(int_col1) CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (0) as t2(int_col1); +-- Set the cross join enabled flag for the LEFT JOIN test since there's no join condition. +-- Ultimately the join should be optimized away. +set spark.sql.crossJoin.enabled = true; SELECT * FROM ( SELECT @@ -31,6 +34,6 @@ SELECT FROM t1 LEFT JOIN t2 ON false ) t where (t.int_col) is not null; - +set spark.sql.crossJoin.enabled = false; diff --git a/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out b/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out new file mode 100644 index 0000000000..562e174fc0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out @@ -0,0 +1,129 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 12 + + +-- !query 0 +create temporary view nt1 as select * from values + ("one", 1), + ("two", 2), + ("three", 3) + as nt1(k, v1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view nt2 as select * from values + ("one", 1), + ("two", 22), + ("one", 5) + as nt2(k, v2) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM nt1 cross join nt2 +-- !query 2 schema +struct<k:string,v1:int,k:string,v2:int> +-- !query 2 output +one 1 one 1 +one 1 one 5 +one 1 two 22 +three 3 one 1 +three 3 one 5 +three 3 two 22 +two 2 one 1 +two 2 one 5 +two 2 two 22 + + +-- !query 3 +SELECT * FROM nt1 cross join nt2 where nt1.k = nt2.k +-- !query 3 schema +struct<k:string,v1:int,k:string,v2:int> +-- !query 3 output +one 1 one 1 +one 1 one 5 +two 2 two 22 + + +-- !query 4 +SELECT * FROM nt1 cross join nt2 on (nt1.k = nt2.k) +-- !query 4 schema +struct<k:string,v1:int,k:string,v2:int> +-- !query 4 output +one 1 one 1 +one 1 one 5 +two 2 two 22 + + +-- !query 5 +SELECT * FROM nt1 cross join nt2 where nt1.v1 = 1 and nt2.v2 = 22 +-- !query 5 schema +struct<k:string,v1:int,k:string,v2:int> +-- !query 5 output +one 1 two 22 + + +-- !query 6 +SELECT a.key, b.key FROM +(SELECT k key FROM nt1 WHERE v1 < 2) a +CROSS JOIN +(SELECT k key FROM nt2 WHERE v2 = 22) b +-- !query 6 schema +struct<key:string,key:string> +-- !query 6 output +one two + + +-- !query 7 +create temporary view A(a, va) as select * from nt1 +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +create temporary view B(b, vb) as select * from nt1 +-- !query 8 schema +struct<> +-- !query 8 output + + + +-- !query 9 +create temporary view C(c, vc) as select * from nt1 +-- !query 9 schema +struct<> +-- !query 9 output + + + +-- !query 10 +create temporary view D(d, vd) as select * from nt1 +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +select * from ((A join B on (a = b)) cross join C) join D on (a = d) +-- !query 11 schema +struct<a:string,va:int,b:string,vb:int,c:string,vc:int,d:string,vd:int> +-- !query 11 output +one 1 one 1 one 1 one 1 +one 1 one 1 three 3 one 1 +one 1 one 1 two 2 one 1 +three 3 three 3 one 1 three 3 +three 3 three 3 three 3 three 3 +three 3 three 3 two 2 three 3 +two 2 two 2 one 1 two 2 +two 2 two 2 three 3 two 2 +two 2 two 2 two 2 two 2 diff --git a/sql/core/src/test/resources/sql-tests/results/cte.sql.out b/sql/core/src/test/resources/sql-tests/results/cte.sql.out index ddee5bf2d4..9fbad8f380 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte.sql.out @@ -47,7 +47,7 @@ Table or view not found: s2; line 1 pos 26 -- !query 5 -WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1, t2 +WITH t1 AS (SELECT * FROM t2), t2 AS (SELECT 2 FROM t1) SELECT * FROM t1 cross join t2 -- !query 5 schema struct<id:int,2:int> -- !query 5 output diff --git a/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out index b39fdb0e58..cc50b9444b 100644 --- a/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 8 -- !query 0 @@ -59,6 +59,14 @@ struct<> -- !query 5 +set spark.sql.crossJoin.enabled = true +-- !query 5 schema +struct<key:string,value:string> +-- !query 5 output +spark.sql.crossJoin.enabled + + +-- !query 6 SELECT * FROM ( SELECT @@ -66,7 +74,15 @@ SELECT FROM t1 LEFT JOIN t2 ON false ) t where (t.int_col) is not null --- !query 5 schema +-- !query 6 schema struct<int_col:int> --- !query 5 output +-- !query 6 output 97 + + +-- !query 7 +set spark.sql.crossJoin.enabled = false +-- !query 7 schema +struct<key:string,value:string> +-- !query 7 output +spark.sql.crossJoin.enabled diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 4abf5e42b9..541ffb58e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -104,6 +104,21 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { .collect().toSeq) } + test("join - cross join") { + val df1 = Seq((1, "1"), (3, "3")).toDF("int", "str") + val df2 = Seq((2, "2"), (4, "4")).toDF("int", "str") + + checkAnswer( + df1.crossJoin(df2), + Row(1, "1", 2, "2") :: Row(1, "1", 4, "4") :: + Row(3, "3", 2, "2") :: Row(3, "3", 4, "4") :: Nil) + + checkAnswer( + df2.crossJoin(df1), + Row(2, "2", 1, "1") :: Row(2, "2", 3, "3") :: + Row(4, "4", 1, "1") :: Row(4, "4", 3, "3") :: Nil) + } + test("join - using aliases after self join") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") checkAnswer( @@ -145,7 +160,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size === 1) // no join key -- should not be a broadcast join - val plan2 = df1.join(broadcast(df2)).queryExecution.sparkPlan + val plan2 = df1.crossJoin(broadcast(df2)).queryExecution.sparkPlan assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size === 0) // planner should not crash without a join @@ -155,7 +170,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { withTempPath { path => df1.write.parquet(path.getCanonicalPath) val pf1 = spark.read.parquet(path.getCanonicalPath) - assert(df1.join(broadcast(pf1)).count() === 4) + assert(df1.crossJoin(broadcast(pf1)).count() === 4) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f89951760f..c2d256bdd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -626,9 +626,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("drop(name: String) search and drop all top level columns that matchs the name") { val df1 = Seq((1, 2)).toDF("a", "b") val df2 = Seq((3, 4)).toDF("a", "b") - checkAnswer(df1.join(df2), Row(1, 2, 3, 4)) + checkAnswer(df1.crossJoin(df2), Row(1, 2, 3, 4)) // Finds and drops all columns that match the name (case insensitive). - checkAnswer(df1.join(df2).drop("A"), Row(2, 4)) + checkAnswer(df1.crossJoin(df2).drop("A"), Row(2, 4)) } test("withColumnRenamed") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8ce6ea66b6..3243f352a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -466,7 +466,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("self join") { val ds = Seq("1", "2").toDS().as("a") - val joined = ds.joinWith(ds, lit(true)) + val joined = ds.joinWith(ds, lit(true), "cross") checkDataset(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")) } @@ -486,7 +486,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("Kryo encoder self join") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() - assert(ds.joinWith(ds, lit(true)).collect().toSet == + assert(ds.joinWith(ds, lit(true), "cross").collect().toSet == Set( (KryoData(1), KryoData(1)), (KryoData(1), KryoData(2)), @@ -514,7 +514,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("Java encoder self join") { implicit val kryoEncoder = Encoders.javaSerialization[JavaData] val ds = Seq(JavaData(1), JavaData(2)).toDS() - assert(ds.joinWith(ds, lit(true)).collect().toSet == + assert(ds.joinWith(ds, lit(true), "cross").collect().toSet == Set( (JavaData(1), JavaData(1)), (JavaData(1), JavaData(2)), @@ -532,7 +532,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() checkDataset( - ds1.joinWith(ds2, lit(true)), + ds1.joinWith(ds2, lit(true), "cross"), ((nullInt, "1"), (nullInt, "1")), ((nullInt, "1"), (new java.lang.Integer(22), "2")), ((new java.lang.Integer(22), "2"), (nullInt, "1")), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 44889d92ee..913b2ae976 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -225,8 +225,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(2, 2, 1, null) :: Row(2, 2, 2, 2) :: Nil) } - assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive and are " + - "disabled by default")) + assert(e.getMessage.contains("Detected cartesian product for INNER join " + + "between logical plans")) } } @@ -482,7 +482,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { // we set the threshold is greater than statistic of the cached table testData withSQLConf( - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) { + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString(), + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { assert(statisticSizeInByte(spark.table("testData2")) > spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) @@ -573,4 +574,34 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(3, 1) :: Row(3, 2) :: Nil) } + + test("cross join detection") { + testData.createOrReplaceTempView("A") + testData.createOrReplaceTempView("B") + testData2.createOrReplaceTempView("C") + testData3.createOrReplaceTempView("D") + upperCaseData.where('N >= 3).createOrReplaceTempView("`right`") + val cartesianQueries = Seq( + /** The following should error out since there is no explicit cross join */ + "SELECT * FROM testData inner join testData2", + "SELECT * FROM testData left outer join testData2", + "SELECT * FROM testData right outer join testData2", + "SELECT * FROM testData full outer join testData2", + "SELECT * FROM testData, testData2", + "SELECT * FROM testData, testData2 where testData.key = 1 and testData2.a = 22", + /** The following should fail because after reordering there are cartesian products */ + "select * from (A join B on (A.key = B.key)) join D on (A.key=D.a) join C", + "select * from ((A join B on (A.key = B.key)) join C) join D on (A.key = D.a)", + /** Cartesian product involving C, which is not involved in a CROSS join */ + "select * from ((A join B on (A.key = B.key)) cross join D) join C on (A.key = D.a)"); + + def checkCartesianDetection(query: String): Unit = { + val e = intercept[Exception] { + checkAnswer(sql(query), Nil); + } + assert(e.getMessage.contains("Detected cartesian product")) + } + + cartesianQueries.foreach(checkCartesianDetection) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala index d3cfa953a3..afd47897ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala @@ -361,7 +361,8 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { |with | v0 as (select 0 as key, 1 as value), | v1 as (select key, count(value) over (partition by key) cnt_val from v0), - | v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key) + | v2 as (select v1.key, v1_lag.cnt_val from v1 cross join v1 v1_lag + | where v1.key = v1_lag.key) | select key, cnt_val from v2 order by key limit 1 """.stripMargin), Row(0, 1)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 35dab63672..4408ece112 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -109,8 +109,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftPlan: SparkPlan, rightPlan: SparkPlan, side: BuildSide) = { - val shuffledHashJoin = - joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan) + val shuffledHashJoin = joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner, + side, None, leftPlan, rightPlan) val filteredJoin = boundCondition.map(FilterExec(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) EnsureRequirements(spark.sessionState.conf).apply(filteredJoin) @@ -122,8 +122,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { boundCondition: Option[Expression], leftPlan: SparkPlan, rightPlan: SparkPlan) = { - val sortMergeJoin = - joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan) + val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition, + leftPlan, rightPlan) EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) } |