aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJurriaan Pruis <email@jurriaanpruis.nl>2016-05-21 23:01:14 -0700
committerReynold Xin <rxin@databricks.com>2016-05-21 23:01:14 -0700
commit223f6339088434eb3590c2f42091a38f05f1e5db (patch)
tree0ae551a262d4b35f10603f53061f5be2a2ac642c
parentdf9adb5ec994f054b2fa58e492867bbc5a60c234 (diff)
downloadspark-223f6339088434eb3590c2f42091a38f05f1e5db.tar.gz
spark-223f6339088434eb3590c2f42091a38f05f1e5db.tar.bz2
spark-223f6339088434eb3590c2f42091a38f05f1e5db.zip
[SPARK-15415][SQL] Fix BroadcastHint when autoBroadcastJoinThreshold is 0 or -1
## What changes were proposed in this pull request? This PR makes BroadcastHint more deterministic by using a special isBroadcastable property instead of setting the sizeInBytes to 1. See https://issues.apache.org/jira/browse/SPARK-15415 ## How was this patch tested? Added testcases to test if the broadcast hash join is included in the plan when the BroadcastHint is supplied and also tests for propagation of the joins. Author: Jurriaan Pruis <email@jurriaanpruis.nl> Closes #13244 from jurriaan/broadcast-hint.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala29
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala103
5 files changed, 114 insertions, 26 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 45ac126a72..4984f235b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -313,7 +313,8 @@ abstract class UnaryNode extends LogicalPlan {
// (product of children).
sizeInBytes = 1
}
- Statistics(sizeInBytes = sizeInBytes)
+
+ child.statistics.copy(sizeInBytes = sizeInBytes)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index 9ac4c3a2a5..63f86ad094 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -32,4 +32,4 @@ package org.apache.spark.sql.catalyst.plans.logical
* @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
* defaults to the product of children's `sizeInBytes`.
*/
-private[sql] case class Statistics(sizeInBytes: BigInt)
+private[sql] case class Statistics(sizeInBytes: BigInt, isBroadcastable: Boolean = false)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 732b0d7919..bed48b6f61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -163,7 +163,9 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
val leftSize = left.statistics.sizeInBytes
val rightSize = right.statistics.sizeInBytes
val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
- Statistics(sizeInBytes = sizeInBytes)
+ val isBroadcastable = left.statistics.isBroadcastable || right.statistics.isBroadcastable
+
+ Statistics(sizeInBytes = sizeInBytes, isBroadcastable = isBroadcastable)
}
}
@@ -183,7 +185,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
duplicateResolved
override def statistics: Statistics = {
- Statistics(sizeInBytes = left.statistics.sizeInBytes)
+ left.statistics.copy()
}
}
@@ -330,6 +332,16 @@ case class Join(
case UsingJoin(_, _) => false
case _ => resolvedExceptNatural
}
+
+ override def statistics: Statistics = joinType match {
+ case LeftAnti | LeftSemi =>
+ // LeftSemi and LeftAnti won't ever be bigger than left
+ left.statistics.copy()
+ case _ =>
+ // make sure we don't propagate isBroadcastable in other joins, because
+ // they could explode the size.
+ super.statistics.copy(isBroadcastable = false)
+ }
}
/**
@@ -338,9 +350,8 @@ case class Join(
case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
- // We manually set statistics of BroadcastHint to smallest value to make sure
- // the plan wrapped by BroadcastHint will be considered to broadcast later.
- override def statistics: Statistics = Statistics(sizeInBytes = 1)
+ // set isBroadcastable to true so the child will be broadcasted
+ override def statistics: Statistics = super.statistics.copy(isBroadcastable = true)
}
case class InsertIntoTable(
@@ -465,7 +476,7 @@ case class Aggregate(
override def statistics: Statistics = {
if (groupingExpressions.isEmpty) {
- Statistics(sizeInBytes = 1)
+ super.statistics.copy(sizeInBytes = 1)
} else {
super.statistics
}
@@ -638,7 +649,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
override lazy val statistics: Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
- Statistics(sizeInBytes = sizeInBytes)
+ child.statistics.copy(sizeInBytes = sizeInBytes)
}
}
@@ -653,7 +664,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
override lazy val statistics: Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
- Statistics(sizeInBytes = sizeInBytes)
+ child.statistics.copy(sizeInBytes = sizeInBytes)
}
}
@@ -690,7 +701,7 @@ case class Sample(
if (sizeInBytes == 0) {
sizeInBytes = 1
}
- Statistics(sizeInBytes = sizeInBytes)
+ child.statistics.copy(sizeInBytes = sizeInBytes)
}
override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3343039ae1..664e7f5661 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -92,7 +92,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* Matches a plan whose output should be small enough to be used in broadcast join.
*/
private def canBroadcast(plan: LogicalPlan): Boolean = {
- plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold
+ plan.statistics.isBroadcastable ||
+ plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 730ec43556..e681b88685 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,9 +22,12 @@ import scala.reflect.ClassTag
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
-import org.apache.spark.sql.{QueryTest, SparkSession}
+import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.execution.exchange.EnsureRequirements
+import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SQLTestUtils
/**
* Test various broadcast join operators.
@@ -33,7 +36,9 @@ import org.apache.spark.sql.functions._
* unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered
* without serializing the hashed relation, which does not happen in local mode.
*/
-class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
+class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
+ import testImplicits._
+
protected var spark: SparkSession = null
/**
@@ -56,30 +61,100 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
/**
* Test whether the specified broadcast join updates the peak execution memory accumulator.
*/
- private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = {
+ private def testBroadcastJoinPeak[T: ClassTag](name: String, joinType: String): Unit = {
AccumulatorSuite.verifyPeakExecutionMemorySet(spark.sparkContext, name) {
- val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
- val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
- // Comparison at the end is for broadcast left semi join
- val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
- val df3 = df1.join(broadcast(df2), joinExpression, joinType)
- val plan =
- EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan)
- assert(plan.collect { case p: T => p }.size === 1)
+ val plan = testBroadcastJoin[T](joinType)
plan.executeCollect()
}
}
+ private def testBroadcastJoin[T: ClassTag](joinType: String,
+ forceBroadcast: Boolean = false): SparkPlan = {
+ val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+ var df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+
+ // Comparison at the end is for broadcast left semi join
+ val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
+ val df3 = if (forceBroadcast) {
+ df1.join(broadcast(df2), joinExpression, joinType)
+ } else {
+ df1.join(df2, joinExpression, joinType)
+ }
+ val plan =
+ EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan)
+ assert(plan.collect { case p: T => p }.size === 1)
+
+ return plan
+ }
+
test("unsafe broadcast hash join updates peak execution memory") {
- testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast hash join", "inner")
+ testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast hash join", "inner")
}
test("unsafe broadcast hash outer join updates peak execution memory") {
- testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast hash outer join", "left_outer")
+ testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast hash outer join", "left_outer")
}
test("unsafe broadcast left semi join updates peak execution memory") {
- testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast left semi join", "leftsemi")
+ testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast left semi join", "leftsemi")
+ }
+
+ test("broadcast hint isn't bothered by authBroadcastJoinThreshold set to low values") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+ testBroadcastJoin[BroadcastHashJoinExec]("inner", true)
+ }
+ }
+
+ test("broadcast hint isn't bothered by a disabled authBroadcastJoinThreshold") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ testBroadcastJoin[BroadcastHashJoinExec]("inner", true)
+ }
+ }
+
+ test("broadcast hint isn't propagated after a join") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+ val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+ val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key"))
+
+ val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value")
+ val df5 = df4.join(df3, Seq("key"), "inner")
+
+ val plan =
+ EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)
+
+ assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
+ assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1)
+ }
}
+ private def assertBroadcastJoin(df : Dataset[Row]) : Unit = {
+ val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+ val joined = df1.join(df, Seq("key"), "inner")
+
+ val plan =
+ EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)
+
+ assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
+ }
+
+ test("broadcast hint is propagated correctly") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value")
+ val broadcasted = broadcast(df2)
+ val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value")
+
+ val cases = Seq(broadcasted.limit(2),
+ broadcasted.filter("value < 10"),
+ broadcasted.sample(true, 0.5),
+ broadcasted.distinct(),
+ broadcasted.groupBy("value").agg(min($"key").as("key")),
+ // except and intersect are semi/anti-joins which won't return more data then
+ // their left argument, so the broadcast hint should be propagated here
+ broadcasted.except(df3),
+ broadcasted.intersect(df3))
+
+ cases.foreach(assertBroadcastJoin)
+ }
+ }
}