author     Sameer Agarwal <sameer@databricks.com>    2016-05-22 23:32:39 -0700
committer  Reynold Xin <rxin@databricks.com>         2016-05-22 23:32:39 -0700
commit     dafcb05c2ef8e09f45edfb7eabf58116c23975a0 (patch)
tree       7c37771c4144b61cd31831e7de4671b0e6b42e12
parent     fc44b694bf5162b3a044768da4627b9969909829 (diff)
[SPARK-15425][SQL] Disallow cross joins by default
## What changes were proposed in this pull request?

In order to prevent users from inadvertently writing queries with cartesian joins, this patch introduces a new conf `spark.sql.crossJoin.enabled` (`false` by default). When the flag is disabled, any query whose physical plan contains one or more cartesian products fails with an `AnalysisException` at execution time.

## How was this patch tested?

Added a test in `JoinSuite` to verify the new behavior. Additionally, `SQLQuerySuite` and `SQLMetricsSuite` were modified to explicitly enable cartesian products.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #13209 from sameeragarwal/disallow-cartesian.
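For context, a minimal sketch of the user-facing behavior this patch introduces, assuming a Spark 2.0-era `SparkSession`; the application name and the `t1`/`t2` view names are made up for the example and are not part of the patch. Auto-broadcasting is disabled so that the tiny inputs are actually planned as a `CartesianProductExec` rather than a broadcast nested loop join:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("cross-join-opt-in").master("local[*]").getOrCreate()
import spark.implicits._

// Two tiny tables; disable auto-broadcast so the join below falls back to
// CartesianProductExec instead of a (permitted) broadcast nested loop join.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
Seq((1, "a"), (2, "b")).toDF("id", "v").createOrReplaceTempView("t1")
Seq((10, "x"), (20, "y")).toDF("id", "v").createOrReplaceTempView("t2")

// With the default spark.sql.crossJoin.enabled=false, executing this cartesian
// product fails with an AnalysisException thrown from CartesianProductExec.doPrepare:
// spark.sql("SELECT * FROM t1 JOIN t2").collect()

// Opting in explicitly lets the query run (4 rows here).
spark.conf.set("spark.sql.crossJoin.enabled", "true")
spark.sql("SELECT * FROM t1 JOIN t2").show()
```

The flag can also be set per session with `SET spark.sql.crossJoin.enabled=true`; note that the check fires when the plan is prepared for execution (`doPrepare`), not during analysis.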
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala  14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala  11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala  31
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala  32
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala  46
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala  4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala  6
10 files changed, 113 insertions(+), 46 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 555a2f4c01..c46cecc71f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -190,7 +190,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
// This join could be very slow or OOM
joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
+ planLater(left), planLater(right), buildSide, joinType, condition,
+ withinBroadcastThreshold = false) :: Nil
// --- Cases where this strategy does not apply ---------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 2a250ecce6..4d43765f8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -19,12 +19,14 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.collection.{BitSet, CompactBuffer}
case class BroadcastNestedLoopJoinExec(
@@ -32,7 +34,8 @@ case class BroadcastNestedLoopJoinExec(
right: SparkPlan,
buildSide: BuildSide,
joinType: JoinType,
- condition: Option[Expression]) extends BinaryExecNode {
+ condition: Option[Expression],
+ withinBroadcastThreshold: Boolean = true) extends BinaryExecNode {
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -337,6 +340,15 @@ case class BroadcastNestedLoopJoinExec(
)
}
+ protected override def doPrepare(): Unit = {
+ if (!withinBroadcastThreshold && !sqlContext.conf.crossJoinEnabled) {
+ throw new AnalysisException("Both sides of this join are outside the broadcasting " +
+ "threshold and computing it could be prohibitively expensive. To explicitly enable it, " +
+ s"please set ${SQLConf.CROSS_JOINS_ENABLED.key} = true")
+ }
+ super.doPrepare()
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 8d7ecc442a..88f78a7a73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark._
import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
@@ -88,6 +90,15 @@ case class CartesianProductExec(
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+ protected override def doPrepare(): Unit = {
+ if (!sqlContext.conf.crossJoinEnabled) {
+ throw new AnalysisException("Cartesian joins could be prohibitively expensive and are " +
+ "disabled by default. To explicitly enable them, please set " +
+ s"${SQLConf.CROSS_JOINS_ENABLED.key} = true")
+ }
+ super.doPrepare()
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 35d67ca2d8..f3064eb6ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -338,9 +338,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val CROSS_JOINS_ENABLED = SQLConfigBuilder("spark.sql.crossJoin.enabled")
+ .doc("When false, we will throw an error if a query contains a cross join")
+ .booleanConf
+ .createWithDefault(false)
+
val ORDER_BY_ORDINAL = SQLConfigBuilder("spark.sql.orderByOrdinal")
.doc("When true, the ordinal numbers are treated as the position in the select list. " +
- "When false, the ordinal numbers in order/sort By clause are ignored.")
+ "When false, the ordinal numbers in order/sort by clause are ignored.")
.booleanConf
.createWithDefault(true)
@@ -622,6 +627,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED)
+ def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
+
// Do not use a value larger than 4000 as the default value of this property.
// See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD)
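For reference, a hedged sketch of the test-side opt-in pattern that the suite changes below rely on: `withSQLConf` scopes `SQLConf.CROSS_JOINS_ENABLED` to a block and restores the previous value afterwards. The suite name, view names, and test body here are illustrative only and not part of this patch:

```scala
package org.apache.spark.sql  // kept in the sql package so the internal SQLConf helpers are visible

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical sketch of the opt-in idiom used by the modified suites.
class CrossJoinOptInSketch extends SparkFunSuite with SharedSQLContext {

  test("cartesian products require an explicit opt-in") {
    spark.range(2).createOrReplaceTempView("t1")
    spark.range(2).createOrReplaceTempView("t2")

    // Disable auto-broadcast so the join is planned as CartesianProductExec
    // rather than a (permitted) broadcast nested loop join within the threshold.
    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
      // With the default spark.sql.crossJoin.enabled=false, execution fails in doPrepare.
      val e = intercept[Exception] { spark.sql("SELECT * FROM t1 JOIN t2").count() }
      assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive"))

      // With an explicit opt-in, the same query runs and yields 2 x 2 = 4 rows.
      withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
        assert(spark.sql("SELECT * FROM t1 JOIN t2").count() == 4)
      }
    }
  }
}
```

Auto-broadcast is disabled in the sketch because small in-memory inputs would otherwise be planned as a broadcast nested loop join within the threshold, which the new check deliberately leaves untouched (`withinBroadcastThreshold = true`).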
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index a5d8cb19ea..5583673708 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -62,7 +62,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
test("join operator selection") {
spark.cacheManager.clearCache()
- withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
+ withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0",
+ SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
classOf[SortMergeJoinExec]),
@@ -204,13 +205,27 @@ class JoinSuite extends QueryTest with SharedSQLContext {
testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
- test("cartisian product join") {
- checkAnswer(
- testData3.join(testData3),
- Row(1, null, 1, null) ::
- Row(1, null, 2, 2) ::
- Row(2, 2, 1, null) ::
- Row(2, 2, 2, 2) :: Nil)
+ test("cartesian product join") {
+ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ checkAnswer(
+ testData3.join(testData3),
+ Row(1, null, 1, null) ::
+ Row(1, null, 2, 2) ::
+ Row(2, 2, 1, null) ::
+ Row(2, 2, 2, 2) :: Nil)
+ }
+ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
+ val e = intercept[Exception] {
+ checkAnswer(
+ testData3.join(testData3),
+ Row(1, null, 1, null) ::
+ Row(1, null, 2, 2) ::
+ Row(2, 2, 1, null) ::
+ Row(2, 2, 2, 2) :: Nil)
+ }
+ assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive and are " +
+ "disabled by default"))
+ }
}
test("left outer join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 460e34a5ff..b1f848fdc8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -104,9 +104,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
).toDF("a", "b", "c").createOrReplaceTempView("cachedData")
spark.catalog.cacheTable("cachedData")
- checkAnswer(
- sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"),
- Row(0) :: Row(81) :: Nil)
+ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ checkAnswer(
+ sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"),
+ Row(0) :: Row(81) :: Nil)
+ }
}
test("self join with aliases") {
@@ -435,10 +437,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("left semi greater than predicate") {
- checkAnswer(
- sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
- Seq(Row(3, 1), Row(3, 2))
- )
+ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ checkAnswer(
+ sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
+ Seq(Row(3, 1), Row(3, 2))
+ )
+ }
}
test("left semi greater than predicate and equal operator") {
@@ -824,12 +828,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("cartesian product join") {
- checkAnswer(
- testData3.join(testData3),
- Row(1, null, 1, null) ::
- Row(1, null, 2, 2) ::
- Row(2, 2, 1, null) ::
- Row(2, 2, 2, 2) :: Nil)
+ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ checkAnswer(
+ testData3.join(testData3),
+ Row(1, null, 1, null) ::
+ Row(1, null, 2, 2) ::
+ Row(2, 2, 1, null) ::
+ Row(2, 2, 2, 2) :: Nil)
+ }
}
test("left outer join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 7caeb3be54..27f6abcd95 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -187,7 +187,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
}
test(s"$testName using CartesianProduct") {
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1",
+ SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
CartesianProductExec(left, right, Some(condition())),
expectedAnswer.map(Row.fromTuple),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 7a89b484eb..12940c86fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.{AccumulatorContext, JsonProtocol, Utils}
@@ -237,16 +238,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
test("BroadcastNestedLoopJoin metrics") {
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.createOrReplaceTempView("testDataForJoin")
- withTempTable("testDataForJoin") {
- // Assume the execution plan is
- // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
- val df = spark.sql(
- "SELECT * FROM testData2 left JOIN testDataForJoin ON " +
- "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
- testSparkPlanMetrics(df, 3, Map(
- 1L -> ("BroadcastNestedLoopJoin", Map(
- "number of output rows" -> 12L)))
- )
+ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ withTempTable("testDataForJoin") {
+ // Assume the execution plan is
+ // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
+ val df = spark.sql(
+ "SELECT * FROM testData2 left JOIN testDataForJoin ON " +
+ "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
+ testSparkPlanMetrics(df, 3, Map(
+ 1L -> ("BroadcastNestedLoopJoin", Map(
+ "number of output rows" -> 12L)))
+ )
+ }
}
}
@@ -263,17 +266,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
}
test("CartesianProduct metrics") {
- val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
- testDataForJoin.createOrReplaceTempView("testDataForJoin")
- withTempTable("testDataForJoin") {
- // Assume the execution plan is
- // ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0)
- val df = spark.sql(
- "SELECT * FROM testData2 JOIN testDataForJoin")
- testSparkPlanMetrics(df, 1, Map(
- 0L -> ("CartesianProduct", Map(
- "number of output rows" -> 12L)))
- )
+ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
+ testDataForJoin.createOrReplaceTempView("testDataForJoin")
+ withTempTable("testDataForJoin") {
+ // Assume the execution plan is
+ // ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0)
+ val df = spark.sql(
+ "SELECT * FROM testData2 JOIN testDataForJoin")
+ testSparkPlanMetrics(df, 1, Map(
+ 0L -> ("CartesianProduct", Map("number of output rows" -> 12L)))
+ )
+ }
}
}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 54fb440b33..a8645f7cd3 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
private val originalColumnBatchSize = TestHive.conf.columnBatchSize
private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc
+ private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled
def testCases: Seq[(String, File)] = {
hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
@@ -61,6 +62,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// Ensures that the plans generation use metastore relation and not OrcRelation
// Was done because SqlBuilder does not work with plans having logical relation
TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, false)
+ // Ensures that cross joins are enabled so that we can test them
+ TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
RuleExecutor.resetTime()
}
@@ -72,6 +75,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc)
+ TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
TestHive.sessionState.functionRegistry.restore()
// For debugging dump some statistics about how much time was spent in various optimizer rules
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 2aaaaadb6a..e179021491 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.internal.SQLConf
case class TestData(a: Int, b: String)
@@ -48,6 +49,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
import org.apache.spark.sql.hive.test.TestHive.implicits._
+ private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled
+
override def beforeAll() {
super.beforeAll()
TestHive.setCacheTables(true)
@@ -55,6 +58,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
// Add Locale setting
Locale.setDefault(Locale.US)
+ // Ensures that cross joins are enabled so that we can test them
+ TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
}
override def afterAll() {
@@ -63,6 +68,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TimeZone.setDefault(originalTimeZone)
Locale.setDefault(originalLocale)
sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2")
+ TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
} finally {
super.afterAll()
}