path: root/sql/core/src/test/scala
author    Srinath Shankar <srinath@databricks.com>    2016-09-03 00:20:43 +0200
committer Herman van Hovell <hvanhovell@databricks.com>    2016-09-03 00:20:43 +0200
commit    e6132a6cf10df8b12af8dd8d1a2c563792b5cc5a (patch)
tree      d706ac4d4091a7ae31eda5c7d62c2d8c2c4a7414 /sql/core/src/test/scala
parent    a2c9acb0e54b2e38cb8ee6431f1ea0e0b4cd959a (diff)
[SPARK-17298][SQL] Require explicit CROSS join for cartesian products
## What changes were proposed in this pull request?

Require the use of CROSS join syntax in SQL (and a new crossJoin DataFrame API) to specify explicit cartesian products between relations. By cartesian product we mean a join between relations R and S where there is no join condition involving columns from both R and S. If a cartesian product is detected in the absence of an explicit CROSS join, an error must be thrown. Turning on the "spark.sql.crossJoin.enabled" configuration flag will disable this check and allow cartesian products without an explicit CROSS join.

The new crossJoin DataFrame API must be used to specify explicit cross joins. The existing join(DataFrame) method will produce an INNER join that will require a subsequent join condition. That is, df1.join(df2) is equivalent to select * from df1, df2.

## How was this patch tested?

Added cross-join.sql to the SQLQueryTestSuite to test the check for cartesian products. Added a couple of tests to the DataFrameJoinSuite to test the crossJoin API. Modified various other test suites to explicitly specify a cross join where an INNER join or a comma-separated list was previously used.

Author: Srinath Shankar <srinath@databricks.com>

Closes #14866 from srinathshankar/crossjoin.
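As a quick illustration of the behavior this commit introduces, here is a minimal sketch (not part of the commit itself; the example DataFrames, view names, and local-mode session are made up for demonstration) showing the new crossJoin API, the explicit CROSS JOIN SQL syntax, and the configuration flag that disables the cartesian-product check:

```scala
import org.apache.spark.sql.SparkSession

object CrossJoinExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("cross-join-example")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val df1 = Seq((1, "a"), (2, "b")).toDF("id", "v1")
    val df2 = Seq((10, "x"), (20, "y")).toDF("id", "v2")

    // Explicit cartesian product via the DataFrame API added in this commit.
    df1.crossJoin(df2).show()

    // Equivalent SQL form: the CROSS keyword marks the cartesian product as intentional.
    df1.createOrReplaceTempView("t1")
    df2.createOrReplaceTempView("t2")
    spark.sql("SELECT * FROM t1 CROSS JOIN t2").show()

    // df1.join(df2) with no join condition is an INNER join; with the check enabled
    // (spark.sql.crossJoin.enabled = false, the default) the analysis fails with
    // "Detected cartesian product for INNER join between logical plans".
    // Setting the flag to true restores the old, permissive behavior:
    // spark.conf.set("spark.sql.crossJoin.enabled", "true")

    spark.stop()
  }
}
```

With the flag left at its default, only the crossJoin and CROSS JOIN forms above succeed; the implicit forms (join without a condition, or a comma-separated FROM list) are rejected, which is exactly what the test changes below exercise.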
Diffstat (limited to 'sql/core/src/test/scala')
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala              | 19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                  |  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                    |  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala                       | 37
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala |  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala  |  8
6 files changed, 63 insertions(+), 16 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 4abf5e42b9..541ffb58e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -104,6 +104,21 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
.collect().toSeq)
}
+ test("join - cross join") {
+ val df1 = Seq((1, "1"), (3, "3")).toDF("int", "str")
+ val df2 = Seq((2, "2"), (4, "4")).toDF("int", "str")
+
+ checkAnswer(
+ df1.crossJoin(df2),
+ Row(1, "1", 2, "2") :: Row(1, "1", 4, "4") ::
+ Row(3, "3", 2, "2") :: Row(3, "3", 4, "4") :: Nil)
+
+ checkAnswer(
+ df2.crossJoin(df1),
+ Row(2, "2", 1, "1") :: Row(2, "2", 3, "3") ::
+ Row(4, "4", 1, "1") :: Row(4, "4", 3, "3") :: Nil)
+ }
+
test("join - using aliases after self join") {
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
checkAnswer(
@@ -145,7 +160,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size === 1)
// no join key -- should not be a broadcast join
- val plan2 = df1.join(broadcast(df2)).queryExecution.sparkPlan
+ val plan2 = df1.crossJoin(broadcast(df2)).queryExecution.sparkPlan
assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size === 0)
// planner should not crash without a join
@@ -155,7 +170,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
withTempPath { path =>
df1.write.parquet(path.getCanonicalPath)
val pf1 = spark.read.parquet(path.getCanonicalPath)
- assert(df1.join(broadcast(pf1)).count() === 4)
+ assert(df1.crossJoin(broadcast(pf1)).count() === 4)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f89951760f..c2d256bdd3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -626,9 +626,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("drop(name: String) search and drop all top level columns that matchs the name") {
val df1 = Seq((1, 2)).toDF("a", "b")
val df2 = Seq((3, 4)).toDF("a", "b")
- checkAnswer(df1.join(df2), Row(1, 2, 3, 4))
+ checkAnswer(df1.crossJoin(df2), Row(1, 2, 3, 4))
// Finds and drops all columns that match the name (case insensitive).
- checkAnswer(df1.join(df2).drop("A"), Row(2, 4))
+ checkAnswer(df1.crossJoin(df2).drop("A"), Row(2, 4))
}
test("withColumnRenamed") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 8ce6ea66b6..3243f352a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -466,7 +466,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("self join") {
val ds = Seq("1", "2").toDS().as("a")
- val joined = ds.joinWith(ds, lit(true))
+ val joined = ds.joinWith(ds, lit(true), "cross")
checkDataset(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
}
@@ -486,7 +486,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("Kryo encoder self join") {
implicit val kryoEncoder = Encoders.kryo[KryoData]
val ds = Seq(KryoData(1), KryoData(2)).toDS()
- assert(ds.joinWith(ds, lit(true)).collect().toSet ==
+ assert(ds.joinWith(ds, lit(true), "cross").collect().toSet ==
Set(
(KryoData(1), KryoData(1)),
(KryoData(1), KryoData(2)),
@@ -514,7 +514,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("Java encoder self join") {
implicit val kryoEncoder = Encoders.javaSerialization[JavaData]
val ds = Seq(JavaData(1), JavaData(2)).toDS()
- assert(ds.joinWith(ds, lit(true)).collect().toSet ==
+ assert(ds.joinWith(ds, lit(true), "cross").collect().toSet ==
Set(
(JavaData(1), JavaData(1)),
(JavaData(1), JavaData(2)),
@@ -532,7 +532,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS()
checkDataset(
- ds1.joinWith(ds2, lit(true)),
+ ds1.joinWith(ds2, lit(true), "cross"),
((nullInt, "1"), (nullInt, "1")),
((nullInt, "1"), (new java.lang.Integer(22), "2")),
((new java.lang.Integer(22), "2"), (nullInt, "1")),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 44889d92ee..913b2ae976 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -225,8 +225,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Row(2, 2, 1, null) ::
Row(2, 2, 2, 2) :: Nil)
}
- assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive and are " +
- "disabled by default"))
+ assert(e.getMessage.contains("Detected cartesian product for INNER join " +
+ "between logical plans"))
}
}
@@ -482,7 +482,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
// we set the threshold is greater than statistic of the cached table testData
withSQLConf(
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) {
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString(),
+ SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
assert(statisticSizeInByte(spark.table("testData2")) >
spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
@@ -573,4 +574,34 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Row(3, 1) ::
Row(3, 2) :: Nil)
}
+
+ test("cross join detection") {
+ testData.createOrReplaceTempView("A")
+ testData.createOrReplaceTempView("B")
+ testData2.createOrReplaceTempView("C")
+ testData3.createOrReplaceTempView("D")
+ upperCaseData.where('N >= 3).createOrReplaceTempView("`right`")
+ val cartesianQueries = Seq(
+ /** The following should error out since there is no explicit cross join */
+ "SELECT * FROM testData inner join testData2",
+ "SELECT * FROM testData left outer join testData2",
+ "SELECT * FROM testData right outer join testData2",
+ "SELECT * FROM testData full outer join testData2",
+ "SELECT * FROM testData, testData2",
+ "SELECT * FROM testData, testData2 where testData.key = 1 and testData2.a = 22",
+ /** The following should fail because after reordering there are cartesian products */
+ "select * from (A join B on (A.key = B.key)) join D on (A.key=D.a) join C",
+ "select * from ((A join B on (A.key = B.key)) join C) join D on (A.key = D.a)",
+ /** Cartesian product involving C, which is not involved in a CROSS join */
+ "select * from ((A join B on (A.key = B.key)) cross join D) join C on (A.key = D.a)");
+
+ def checkCartesianDetection(query: String): Unit = {
+ val e = intercept[Exception] {
+ checkAnswer(sql(query), Nil);
+ }
+ assert(e.getMessage.contains("Detected cartesian product"))
+ }
+
+ cartesianQueries.foreach(checkCartesianDetection)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index d3cfa953a3..afd47897ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -361,7 +361,8 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
|with
| v0 as (select 0 as key, 1 as value),
| v1 as (select key, count(value) over (partition by key) cnt_val from v0),
- | v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key)
+ | v2 as (select v1.key, v1_lag.cnt_val from v1 cross join v1 v1_lag
+ | where v1.key = v1_lag.key)
| select key, cnt_val from v2 order by key limit 1
""".stripMargin), Row(0, 1))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 35dab63672..4408ece112 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -109,8 +109,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
- val shuffledHashJoin =
- joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan)
+ val shuffledHashJoin = joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner,
+ side, None, leftPlan, rightPlan)
val filteredJoin =
boundCondition.map(FilterExec(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
EnsureRequirements(spark.sessionState.conf).apply(filteredJoin)
@@ -122,8 +122,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
boundCondition: Option[Expression],
leftPlan: SparkPlan,
rightPlan: SparkPlan) = {
- val sortMergeJoin =
- joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan)
+ val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition,
+ leftPlan, rightPlan)
EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin)
}