From d9c6039897236c3f1e4503aa95c5c9b07b32eadd Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 27 Oct 2015 20:26:38 -0700 Subject: [SPARK-10484] [SQL] Optimize the cartesian join with broadcast join for some cases In some cases, we can broadcast the smaller relation in cartesian join, which improve the performance significantly. Author: Cheng Hao Closes #8652 from chenghao-intel/cartesian. --- .../apache/spark/sql/execution/SparkPlanner.scala | 3 +- .../spark/sql/execution/SparkStrategies.scala | 38 ++++++--- .../execution/joins/BroadcastNestedLoopJoin.scala | 7 +- .../scala/org/apache/spark/sql/JoinSuite.scala | 92 ++++++++++++++++++++++ .../org/apache/spark/sql/hive/HiveContext.scala | 3 +- ...based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e | 20 +++++ ...ased JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd | 20 +++++ ...ased JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 | 20 +++++ ...ased JOIN #4-0-45f8602d257655322b7d18cad09f6a0f | 20 +++++ .../spark/sql/hive/execution/HiveQuerySuite.scala | 54 +++++++++++++ 10 files changed, 261 insertions(+), 16 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e create mode 100644 sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd create mode 100644 sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 create mode 100644 sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index b346f43fae..0f98fe88b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -44,8 +44,9 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { EquiJoinSelection :: InMemoryScans :: BasicOperators :: + BroadcastNestedLoop :: CartesianProduct :: - BroadcastNestedLoopJoin :: Nil) + DefaultJoin :: Nil) /** * Used to build table scan operators where complex projection and filtering are done using diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 637deff4e2..ee97162853 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -294,25 +294,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - - object BroadcastNestedLoopJoin extends Strategy { + object BroadcastNestedLoop extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, joinType, condition) => - val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - joins.BuildRight - } else { - joins.BuildLeft - } - joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case logical.Join( + CanBroadcast(left), right, joinType, condition) if joinType != LeftSemiJoin => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil + case logical.Join( + left, CanBroadcast(right), joinType, condition) if joinType != LeftSemiJoin => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil case _ => Nil } } object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, _, None) => + // TODO CartesianProduct doesn't support the Left Semi Join + case logical.Join(left, right, joinType, None) if joinType != LeftSemiJoin => execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, @@ -321,6 +320,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object DefaultJoin extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.Join(left, right, joinType, condition) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + joins.BuildRight + } else { + joins.BuildLeft + } + joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case _ => Nil + } + } + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object TakeOrderedAndProject extends Strategy { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index efef8c8a8b..05d20f511a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.CompactBuffer @@ -67,7 +67,10 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => + case Inner => + // TODO we can avoid breaking the lineage, since we union an empty RDD for Inner Join case + left.output ++ right.output + case x => // TODO support the Left Semi Join throw new IllegalArgumentException( s"BroadcastNestedLoopJoin should not take $x as the JoinType") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index b1fb068158..a9ca46cab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -28,6 +28,10 @@ class JoinSuite extends QueryTest with SharedSQLContext { setupTestData() + def statisticSizeInByte(df: DataFrame): BigInt = { + df.queryExecution.optimizedPlan.statistics.sizeInBytes + } + test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") @@ -466,6 +470,94 @@ class JoinSuite extends QueryTest with SharedSQLContext { sql("UNCACHE TABLE testData") } + test("cross join with broadcast") { + sql("CACHE TABLE testData") + + val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData")) + + // we set the threshold is greater than statistic of the cached table testData + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) { + + assert(statisticSizeInByte(sqlContext.table("testData2")) > + sqlContext.conf.autoBroadcastJoinThreshold) + + assert(statisticSizeInByte(sqlContext.table("testData")) < + sqlContext.conf.autoBroadcastJoinThreshold) + + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[LeftSemiJoinHash]), + ("SELECT * FROM testData LEFT SEMI JOIN testData2", + classOf[LeftSemiJoinBNL]), + ("SELECT * FROM testData JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData LEFT JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2 WHERE key > a", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData left JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData right JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key = 2 + """.stripMargin), + Row("2", 1, 1) :: + Row("2", 1, 2) :: + Row("2", 2, 1) :: + Row("2", 2, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key < y.a + """.stripMargin), + Row("1", 2, 1) :: + Row("1", 2, 2) :: + Row("1", 3, 1) :: + Row("1", 3, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y ON x.key < y.a + """.stripMargin), + Row("1", 2, 1) :: + Row("1", 2, 2) :: + Row("1", 3, 1) :: + Row("1", 3, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + } + + sql("UNCACHE TABLE testData") + } + test("left semi join") { val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c328734df3..83a81cf5d1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -588,8 +588,9 @@ class HiveContext private[hive]( LeftSemiJoin, EquiJoinSelection, BasicOperators, + BroadcastNestedLoop, CartesianProduct, - BroadcastNestedLoopJoin + DefaultJoin ) } diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e new file mode 100644 index 0000000000..0bb9399af0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +306 0 +306 0 +306 0 +307 0 +307 0 +307 0 +307 0 +307 0 +307 0 +308 0 +308 0 +308 0 +309 0 +309 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd new file mode 100644 index 0000000000..4e455ed255 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +305 2 +305 4 +306 0 +306 0 +306 0 +306 2 +306 4 +306 5 +306 5 +306 5 +307 0 +307 0 +307 0 +307 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 new file mode 100644 index 0000000000..4e455ed255 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +305 2 +305 4 +306 0 +306 0 +306 0 +306 2 +306 4 +306 5 +306 5 +306 5 +307 0 +307 0 +307 0 +307 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f new file mode 100644 index 0000000000..4e455ed255 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +305 2 +305 4 +306 0 +306 0 +306 0 +306 2 +306 4 +306 5 +306 5 +306 5 +307 0 +307 0 +307 0 +307 0 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2878500453..b52f7d4b57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} +import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin + import scala.util.Try import org.scalatest.BeforeAndAfter @@ -69,6 +71,58 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + // Testing the Broadcast based join for cartesian join (cross join) + // We assume that the Broadcast Join Threshold will works since the src is a small table + private val spark_10484_1 = """ + | SELECT a.key, b.key + | FROM src a LEFT JOIN src b WHERE a.key > b.key + 300 + | ORDER BY b.key, a.key + | LIMIT 20 + """.stripMargin + private val spark_10484_2 = """ + | SELECT a.key, b.key + | FROM src a RIGHT JOIN src b WHERE a.key > b.key + 300 + | ORDER BY a.key, b.key + | LIMIT 20 + """.stripMargin + private val spark_10484_3 = """ + | SELECT a.key, b.key + | FROM src a FULL OUTER JOIN src b WHERE a.key > b.key + 300 + | ORDER BY a.key, b.key + | LIMIT 20 + """.stripMargin + private val spark_10484_4 = """ + | SELECT a.key, b.key + | FROM src a JOIN src b WHERE a.key > b.key + 300 + | ORDER BY a.key, b.key + | LIMIT 20 + """.stripMargin + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1", + spark_10484_1) + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2", + spark_10484_2) + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3", + spark_10484_3) + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4", + spark_10484_4) + + test("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN") { + def assertBroadcastNestedLoopJoin(sqlText: String): Unit = { + assert(sql(sqlText).queryExecution.sparkPlan.collect { + case _: BroadcastNestedLoopJoin => 1 + }.nonEmpty) + } + + assertBroadcastNestedLoopJoin(spark_10484_1) + assertBroadcastNestedLoopJoin(spark_10484_2) + assertBroadcastNestedLoopJoin(spark_10484_3) + assertBroadcastNestedLoopJoin(spark_10484_4) + } + createQueryTest("SPARK-8976 Wrong Result for Rollup #1", """ SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP -- cgit v1.2.3