diff options
author | Jurriaan Pruis <email@jurriaanpruis.nl> | 2016-05-21 23:01:14 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-05-21 23:01:14 -0700 |
commit | 223f6339088434eb3590c2f42091a38f05f1e5db (patch) | |
tree | 0ae551a262d4b35f10603f53061f5be2a2ac642c | |
parent | df9adb5ec994f054b2fa58e492867bbc5a60c234 (diff) | |
download | spark-223f6339088434eb3590c2f42091a38f05f1e5db.tar.gz spark-223f6339088434eb3590c2f42091a38f05f1e5db.tar.bz2 spark-223f6339088434eb3590c2f42091a38f05f1e5db.zip |
[SPARK-15415][SQL] Fix BroadcastHint when autoBroadcastJoinThreshold is 0 or -1
## What changes were proposed in this pull request?
This PR makes BroadcastHint more deterministic by using a special isBroadcastable property
instead of setting the sizeInBytes to 1.
See https://issues.apache.org/jira/browse/SPARK-15415
## How was this patch tested?
Added testcases to test if the broadcast hash join is included in the plan when the BroadcastHint is supplied and also tests for propagation of the joins.
Author: Jurriaan Pruis <email@jurriaanpruis.nl>
Closes #13244 from jurriaan/broadcast-hint.
5 files changed, 114 insertions, 26 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 45ac126a72..4984f235b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -313,7 +313,8 @@ abstract class UnaryNode extends LogicalPlan { // (product of children). sizeInBytes = 1 } - Statistics(sizeInBytes = sizeInBytes) + + child.statistics.copy(sizeInBytes = sizeInBytes) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 9ac4c3a2a5..63f86ad094 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -32,4 +32,4 @@ package org.apache.spark.sql.catalyst.plans.logical * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it * defaults to the product of children's `sizeInBytes`. */ -private[sql] case class Statistics(sizeInBytes: BigInt) +private[sql] case class Statistics(sizeInBytes: BigInt, isBroadcastable: Boolean = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 732b0d7919..bed48b6f61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -163,7 +163,9 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation val leftSize = left.statistics.sizeInBytes val rightSize = right.statistics.sizeInBytes val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize - Statistics(sizeInBytes = sizeInBytes) + val isBroadcastable = left.statistics.isBroadcastable || right.statistics.isBroadcastable + + Statistics(sizeInBytes = sizeInBytes, isBroadcastable = isBroadcastable) } } @@ -183,7 +185,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le duplicateResolved override def statistics: Statistics = { - Statistics(sizeInBytes = left.statistics.sizeInBytes) + left.statistics.copy() } } @@ -330,6 +332,16 @@ case class Join( case UsingJoin(_, _) => false case _ => resolvedExceptNatural } + + override def statistics: Statistics = joinType match { + case LeftAnti | LeftSemi => + // LeftSemi and LeftAnti won't ever be bigger than left + left.statistics.copy() + case _ => + // make sure we don't propagate isBroadcastable in other joins, because + // they could explode the size. + super.statistics.copy(isBroadcastable = false) + } } /** @@ -338,9 +350,8 @@ case class Join( case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - // We manually set statistics of BroadcastHint to smallest value to make sure - // the plan wrapped by BroadcastHint will be considered to broadcast later. - override def statistics: Statistics = Statistics(sizeInBytes = 1) + // set isBroadcastable to true so the child will be broadcasted + override def statistics: Statistics = super.statistics.copy(isBroadcastable = true) } case class InsertIntoTable( @@ -465,7 +476,7 @@ case class Aggregate( override def statistics: Statistics = { if (groupingExpressions.isEmpty) { - Statistics(sizeInBytes = 1) + super.statistics.copy(sizeInBytes = 1) } else { super.statistics } @@ -638,7 +649,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN override lazy val statistics: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum - Statistics(sizeInBytes = sizeInBytes) + child.statistics.copy(sizeInBytes = sizeInBytes) } } @@ -653,7 +664,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo override lazy val statistics: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum - Statistics(sizeInBytes = sizeInBytes) + child.statistics.copy(sizeInBytes = sizeInBytes) } } @@ -690,7 +701,7 @@ case class Sample( if (sizeInBytes == 0) { sizeInBytes = 1 } - Statistics(sizeInBytes = sizeInBytes) + child.statistics.copy(sizeInBytes = sizeInBytes) } override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3343039ae1..664e7f5661 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -92,7 +92,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold + plan.statistics.isBroadcastable || + plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 730ec43556..e681b88685 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,9 +22,12 @@ import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} -import org.apache.spark.sql.{QueryTest, SparkSession} +import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils /** * Test various broadcast join operators. @@ -33,7 +36,9 @@ import org.apache.spark.sql.functions._ * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered * without serializing the hashed relation, which does not happen in local mode. */ -class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { +class BroadcastJoinSuite extends QueryTest with SQLTestUtils { + import testImplicits._ + protected var spark: SparkSession = null /** @@ -56,30 +61,100 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { /** * Test whether the specified broadcast join updates the peak execution memory accumulator. */ - private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { + private def testBroadcastJoinPeak[T: ClassTag](name: String, joinType: String): Unit = { AccumulatorSuite.verifyPeakExecutionMemorySet(spark.sparkContext, name) { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") - // Comparison at the end is for broadcast left semi join - val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") - val df3 = df1.join(broadcast(df2), joinExpression, joinType) - val plan = - EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan) - assert(plan.collect { case p: T => p }.size === 1) + val plan = testBroadcastJoin[T](joinType) plan.executeCollect() } } + private def testBroadcastJoin[T: ClassTag](joinType: String, + forceBroadcast: Boolean = false): SparkPlan = { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + var df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + + // Comparison at the end is for broadcast left semi join + val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") + val df3 = if (forceBroadcast) { + df1.join(broadcast(df2), joinExpression, joinType) + } else { + df1.join(df2, joinExpression, joinType) + } + val plan = + EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan) + assert(plan.collect { case p: T => p }.size === 1) + + return plan + } + test("unsafe broadcast hash join updates peak execution memory") { - testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast hash join", "inner") + testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast hash join", "inner") } test("unsafe broadcast hash outer join updates peak execution memory") { - testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast hash outer join", "left_outer") + testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast hash outer join", "left_outer") } test("unsafe broadcast left semi join updates peak execution memory") { - testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast left semi join", "leftsemi") + testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast left semi join", "leftsemi") + } + + test("broadcast hint isn't bothered by authBroadcastJoinThreshold set to low values") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + testBroadcastJoin[BroadcastHashJoinExec]("inner", true) + } + } + + test("broadcast hint isn't bothered by a disabled authBroadcastJoinThreshold") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + testBroadcastJoin[BroadcastHashJoinExec]("inner", true) + } + } + + test("broadcast hint isn't propagated after a join") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key")) + + val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value") + val df5 = df4.join(df3, Seq("key"), "inner") + + val plan = + EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan) + + assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) + assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1) + } } + private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val joined = df1.join(df, Seq("key"), "inner") + + val plan = + EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan) + + assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) + } + + test("broadcast hint is propagated correctly") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value") + val broadcasted = broadcast(df2) + val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value") + + val cases = Seq(broadcasted.limit(2), + broadcasted.filter("value < 10"), + broadcasted.sample(true, 0.5), + broadcasted.distinct(), + broadcasted.groupBy("value").agg(min($"key").as("key")), + // except and intersect are semi/anti-joins which won't return more data then + // their left argument, so the broadcast hint should be propagated here + broadcasted.except(df3), + broadcasted.intersect(df3)) + + cases.foreach(assertBroadcastJoin) + } + } } |