aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2015-10-27 20:26:38 -0700
committerYin Huai <yhuai@databricks.com>2015-10-27 20:26:38 -0700
commitd9c6039897236c3f1e4503aa95c5c9b07b32eadd (patch)
tree3f9ddb1f1c7b91ef3cb7073cd7094522e48b340f
parentb960a890561eaf3795b93c621bd95be81e56f5b7 (diff)
downloadspark-d9c6039897236c3f1e4503aa95c5c9b07b32eadd.tar.gz
spark-d9c6039897236c3f1e4503aa95c5c9b07b32eadd.tar.bz2
spark-d9c6039897236c3f1e4503aa95c5c9b07b32eadd.zip
[SPARK-10484] [SQL] Optimize the cartesian join with broadcast join for some cases
In some cases, we can broadcast the smaller relation in a cartesian join, which improves performance significantly. Author: Cheng Hao <hao.cheng@intel.com> Closes #8652 from chenghao-intel/cartesian.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala38
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala7
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala92
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala3
-rw-r--r--sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e20
-rw-r--r--sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd20
-rw-r--r--sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb420
-rw-r--r--sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f20
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala54
10 files changed, 261 insertions, 16 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index b346f43fae..0f98fe88b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -44,8 +44,9 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies {
EquiJoinSelection ::
InMemoryScans ::
BasicOperators ::
+ BroadcastNestedLoop ::
CartesianProduct ::
- BroadcastNestedLoopJoin :: Nil)
+ DefaultJoin :: Nil)
/**
* Used to build table scan operators where complex projection and filtering are done using
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 637deff4e2..ee97162853 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -294,25 +294,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
-
- object BroadcastNestedLoopJoin extends Strategy {
+ object BroadcastNestedLoop extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.Join(left, right, joinType, condition) =>
- val buildSide =
- if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
- joins.BuildRight
- } else {
- joins.BuildLeft
- }
- joins.BroadcastNestedLoopJoin(
- planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
+ case logical.Join(
+ CanBroadcast(left), right, joinType, condition) if joinType != LeftSemiJoin =>
+ execution.joins.BroadcastNestedLoopJoin(
+ planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil
+ case logical.Join(
+ left, CanBroadcast(right), joinType, condition) if joinType != LeftSemiJoin =>
+ execution.joins.BroadcastNestedLoopJoin(
+ planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil
case _ => Nil
}
}
object CartesianProduct extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.Join(left, right, _, None) =>
+ // TODO CartesianProduct doesn't support the Left Semi Join
+ case logical.Join(left, right, joinType, None) if joinType != LeftSemiJoin =>
execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil
case logical.Join(left, right, Inner, Some(condition)) =>
execution.Filter(condition,
@@ -321,6 +320,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
+ object DefaultJoin extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.Join(left, right, joinType, condition) =>
+ val buildSide =
+ if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
+ joins.BuildRight
+ } else {
+ joins.BuildLeft
+ }
+ joins.BroadcastNestedLoopJoin(
+ planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
+ case _ => Nil
+ }
+ }
+
protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
object TakeOrderedAndProject extends Strategy {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index efef8c8a8b..05d20f511a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.collection.CompactBuffer
@@ -67,7 +67,10 @@ case class BroadcastNestedLoopJoin(
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case x =>
+ case Inner =>
+ // TODO we can avoid breaking the lineage, since we union an empty RDD for Inner Join case
+ left.output ++ right.output
+ case x => // TODO support the Left Semi Join
throw new IllegalArgumentException(
s"BroadcastNestedLoopJoin should not take $x as the JoinType")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index b1fb068158..a9ca46cab0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -28,6 +28,10 @@ class JoinSuite extends QueryTest with SharedSQLContext {
setupTestData()
+ def statisticSizeInByte(df: DataFrame): BigInt = {
+ df.queryExecution.optimizedPlan.statistics.sizeInBytes
+ }
+
test("equi-join is hash-join") {
val x = testData2.as("x")
val y = testData2.as("y")
@@ -466,6 +470,94 @@ class JoinSuite extends QueryTest with SharedSQLContext {
sql("UNCACHE TABLE testData")
}
+ test("cross join with broadcast") {
+ sql("CACHE TABLE testData")
+
+ val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData"))
+
+ // Set the threshold just above the statistics size of the cached table testData
+ withSQLConf(
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) {
+
+ assert(statisticSizeInByte(sqlContext.table("testData2")) >
+ sqlContext.conf.autoBroadcastJoinThreshold)
+
+ assert(statisticSizeInByte(sqlContext.table("testData")) <
+ sqlContext.conf.autoBroadcastJoinThreshold)
+
+ Seq(
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
+ classOf[LeftSemiJoinHash]),
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2",
+ classOf[LeftSemiJoinBNL]),
+ ("SELECT * FROM testData JOIN testData2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData JOIN testData2 WHERE key = 2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData LEFT JOIN testData2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData RIGHT JOIN testData2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData FULL OUTER JOIN testData2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData JOIN testData2 WHERE key > a",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData left JOIN testData2 WHERE (key * a != key + a)",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData right JOIN testData2 WHERE (key * a != key + a)",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)",
+ classOf[BroadcastNestedLoopJoin])
+ ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+
+ checkAnswer(
+ sql(
+ """
+ SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key = 2
+ """.stripMargin),
+ Row("2", 1, 1) ::
+ Row("2", 1, 2) ::
+ Row("2", 2, 1) ::
+ Row("2", 2, 2) ::
+ Row("2", 3, 1) ::
+ Row("2", 3, 2) :: Nil)
+
+ checkAnswer(
+ sql(
+ """
+ SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key < y.a
+ """.stripMargin),
+ Row("1", 2, 1) ::
+ Row("1", 2, 2) ::
+ Row("1", 3, 1) ::
+ Row("1", 3, 2) ::
+ Row("2", 3, 1) ::
+ Row("2", 3, 2) :: Nil)
+
+ checkAnswer(
+ sql(
+ """
+ SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y ON x.key < y.a
+ """.stripMargin),
+ Row("1", 2, 1) ::
+ Row("1", 2, 2) ::
+ Row("1", 3, 1) ::
+ Row("1", 3, 2) ::
+ Row("2", 3, 1) ::
+ Row("2", 3, 2) :: Nil)
+ }
+
+ sql("UNCACHE TABLE testData")
+ }
+
test("left semi join") {
val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
checkAnswer(df,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index c328734df3..83a81cf5d1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -588,8 +588,9 @@ class HiveContext private[hive](
LeftSemiJoin,
EquiJoinSelection,
BasicOperators,
+ BroadcastNestedLoop,
CartesianProduct,
- BroadcastNestedLoopJoin
+ DefaultJoin
)
}
diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e
new file mode 100644
index 0000000000..0bb9399af0
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e
@@ -0,0 +1,20 @@
+302 0
+302 0
+302 0
+305 0
+305 0
+305 0
+306 0
+306 0
+306 0
+307 0
+307 0
+307 0
+307 0
+307 0
+307 0
+308 0
+308 0
+308 0
+309 0
+309 0
diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd
new file mode 100644
index 0000000000..4e455ed255
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd
@@ -0,0 +1,20 @@
+302 0
+302 0
+302 0
+305 0
+305 0
+305 0
+305 2
+305 4
+306 0
+306 0
+306 0
+306 2
+306 4
+306 5
+306 5
+306 5
+307 0
+307 0
+307 0
+307 0
diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4
new file mode 100644
index 0000000000..4e455ed255
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4
@@ -0,0 +1,20 @@
+302 0
+302 0
+302 0
+305 0
+305 0
+305 0
+305 2
+305 4
+306 0
+306 0
+306 0
+306 2
+306 4
+306 5
+306 5
+306 5
+307 0
+307 0
+307 0
+307 0
diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f
new file mode 100644
index 0000000000..4e455ed255
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f
@@ -0,0 +1,20 @@
+302 0
+302 0
+302 0
+305 0
+305 0
+305 0
+305 2
+305 4
+306 0
+306 0
+306 0
+306 2
+306 4
+306 5
+306 5
+306 5
+307 0
+307 0
+307 0
+307 0
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 2878500453..b52f7d4b57 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.execution
import java.io.File
import java.util.{Locale, TimeZone}
+import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin
+
import scala.util.Try
import org.scalatest.BeforeAndAfter
@@ -69,6 +71,58 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
}
+ // Testing the Broadcast based join for cartesian join (cross join)
+ // We assume the broadcast join threshold applies here, since src is a small table
+ private val spark_10484_1 = """
+ | SELECT a.key, b.key
+ | FROM src a LEFT JOIN src b WHERE a.key > b.key + 300
+ | ORDER BY b.key, a.key
+ | LIMIT 20
+ """.stripMargin
+ private val spark_10484_2 = """
+ | SELECT a.key, b.key
+ | FROM src a RIGHT JOIN src b WHERE a.key > b.key + 300
+ | ORDER BY a.key, b.key
+ | LIMIT 20
+ """.stripMargin
+ private val spark_10484_3 = """
+ | SELECT a.key, b.key
+ | FROM src a FULL OUTER JOIN src b WHERE a.key > b.key + 300
+ | ORDER BY a.key, b.key
+ | LIMIT 20
+ """.stripMargin
+ private val spark_10484_4 = """
+ | SELECT a.key, b.key
+ | FROM src a JOIN src b WHERE a.key > b.key + 300
+ | ORDER BY a.key, b.key
+ | LIMIT 20
+ """.stripMargin
+
+ createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1",
+ spark_10484_1)
+
+ createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2",
+ spark_10484_2)
+
+ createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3",
+ spark_10484_3)
+
+ createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4",
+ spark_10484_4)
+
+ test("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN") {
+ def assertBroadcastNestedLoopJoin(sqlText: String): Unit = {
+ assert(sql(sqlText).queryExecution.sparkPlan.collect {
+ case _: BroadcastNestedLoopJoin => 1
+ }.nonEmpty)
+ }
+
+ assertBroadcastNestedLoopJoin(spark_10484_1)
+ assertBroadcastNestedLoopJoin(spark_10484_2)
+ assertBroadcastNestedLoopJoin(spark_10484_3)
+ assertBroadcastNestedLoopJoin(spark_10484_4)
+ }
+
createQueryTest("SPARK-8976 Wrong Result for Rollup #1",
"""
SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP