diff options
author | Cheng Hao <hao.cheng@intel.com> | 2015-10-27 20:26:38 -0700 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2015-10-27 20:26:38 -0700 |
commit | d9c6039897236c3f1e4503aa95c5c9b07b32eadd (patch) | |
tree | 3f9ddb1f1c7b91ef3cb7073cd7094522e48b340f /sql/core | |
parent | b960a890561eaf3795b93c621bd95be81e56f5b7 (diff) | |
download | spark-d9c6039897236c3f1e4503aa95c5c9b07b32eadd.tar.gz spark-d9c6039897236c3f1e4503aa95c5c9b07b32eadd.tar.bz2 spark-d9c6039897236c3f1e4503aa95c5c9b07b32eadd.zip |
[SPARK-10484] [SQL] Optimize the cartesian join with broadcast join for some cases
In some cases, we can broadcast the smaller relation in cartesian join, which improve the performance significantly.
Author: Cheng Hao <hao.cheng@intel.com>
Closes #8652 from chenghao-intel/cartesian.
Diffstat (limited to 'sql/core')
4 files changed, 125 insertions, 15 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index b346f43fae..0f98fe88b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -44,8 +44,9 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { EquiJoinSelection :: InMemoryScans :: BasicOperators :: + BroadcastNestedLoop :: CartesianProduct :: - BroadcastNestedLoopJoin :: Nil) + DefaultJoin :: Nil) /** * Used to build table scan operators where complex projection and filtering are done using diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 637deff4e2..ee97162853 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -294,25 +294,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - - object BroadcastNestedLoopJoin extends Strategy { + object BroadcastNestedLoop extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, joinType, condition) => - val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - joins.BuildRight - } else { - joins.BuildLeft - } - joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case logical.Join( + CanBroadcast(left), right, joinType, condition) if joinType != LeftSemiJoin => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil + case logical.Join( + left, CanBroadcast(right), joinType, condition) if joinType != LeftSemiJoin => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil case _ => Nil } } object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, _, None) => + // TODO CartesianProduct doesn't support the Left Semi Join + case logical.Join(left, right, joinType, None) if joinType != LeftSemiJoin => execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, @@ -321,6 +320,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object DefaultJoin extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.Join(left, right, joinType, condition) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + joins.BuildRight + } else { + joins.BuildLeft + } + joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case _ => Nil + } + } + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object TakeOrderedAndProject extends Strategy { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index efef8c8a8b..05d20f511a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.CompactBuffer @@ -67,7 +67,10 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => + case Inner => + // TODO we can avoid breaking the lineage, since we union an empty RDD for Inner Join case + left.output ++ right.output + case x => // TODO support the Left Semi Join throw new IllegalArgumentException( s"BroadcastNestedLoopJoin should not take $x as the JoinType") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index b1fb068158..a9ca46cab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -28,6 +28,10 @@ class JoinSuite extends QueryTest with SharedSQLContext { setupTestData() + def statisticSizeInByte(df: DataFrame): BigInt = { + df.queryExecution.optimizedPlan.statistics.sizeInBytes + } + test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") @@ -466,6 +470,94 @@ class JoinSuite extends QueryTest with SharedSQLContext { sql("UNCACHE TABLE testData") } + test("cross join with broadcast") { + sql("CACHE TABLE testData") + + val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData")) + + // we set the threshold is greater than statistic of the cached table testData + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) { + + assert(statisticSizeInByte(sqlContext.table("testData2")) > + sqlContext.conf.autoBroadcastJoinThreshold) + + assert(statisticSizeInByte(sqlContext.table("testData")) < + sqlContext.conf.autoBroadcastJoinThreshold) + + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[LeftSemiJoinHash]), + ("SELECT * FROM testData LEFT SEMI JOIN testData2", + classOf[LeftSemiJoinBNL]), + ("SELECT * FROM testData JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData LEFT JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2 WHERE key > a", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData left JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData right JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key = 2 + """.stripMargin), + Row("2", 1, 1) :: + Row("2", 1, 2) :: + Row("2", 2, 1) :: + Row("2", 2, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key < y.a + """.stripMargin), + Row("1", 2, 1) :: + Row("1", 2, 2) :: + Row("1", 3, 1) :: + Row("1", 3, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y ON x.key < y.a + """.stripMargin), + Row("1", 2, 1) :: + Row("1", 2, 2) :: + Row("1", 3, 1) :: + Row("1", 3, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + } + + sql("UNCACHE TABLE testData") + } + test("left semi join") { val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, |