diff options
author | wangxiaojing <u9jing@gmail.com> | 2014-12-30 13:54:12 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-12-30 13:54:12 -0800 |
commit | 07fa1910d9c4092d670381c447403105f01c584e (patch) | |
tree | 6552fd64d119b5333145b49200a598676145c534 /sql | |
parent | 8f29b7cafc2b6e802e4eb21f681d6369da2f30fa (diff) | |
download | spark-07fa1910d9c4092d670381c447403105f01c584e.tar.gz spark-07fa1910d9c4092d670381c447403105f01c584e.tar.bz2 spark-07fa1910d9c4092d670381c447403105f01c584e.zip |
[SPARK-4570][SQL]add BroadcastLeftSemiJoinHash
JIRA issue: [SPARK-4570](https://issues.apache.org/jira/browse/SPARK-4570)
We are planning to create a `BroadcastLeftSemiJoinHash` operator to implement the broadcast join for `left semi join`.
For a left semi join:
If broadcast joins are enabled (the user-settable threshold `AUTO_BROADCASTJOIN_THRESHOLD` is positive) and the size of the data from the right side does not exceed that threshold,
the planner marks the right side as the `broadcast` relation and the other relation as the stream side. The broadcast table is sent to all of the executors involved in the join as an `org.apache.spark.broadcast.Broadcast` object, and the join is executed by `joins.BroadcastLeftSemiJoinHash`; otherwise the planner falls back to `joins.LeftSemiJoinHash`.
The benchmark below suggests that this change makes `left semi join` queries more than 4x faster:
<pre><code>
Original:
left semi join : 9288 ms
Optimized:
left semi join : 1963 ms
</code></pre>
The micro benchmark loads `/data1/kv3.txt` into a normal Hive table.
Benchmark code:
<pre><code>
/**
 * Runs `f` exactly once and returns how long it took, in milliseconds.
 *
 * The duration is measured with `System.nanoTime()`, which is meant for measuring
 * elapsed time, rather than the wall-clock `System.currentTimeMillis()`, whose value
 * can jump when the system clock is adjusted and so distort (or even negate) the result.
 *
 * @param f by-name block to time; it is evaluated once, between the two clock reads
 * @return elapsed time in whole milliseconds
 */
def benchmark(f: => Unit): Long = {
  val startNanos = System.nanoTime()
  f
  // Convert to milliseconds so callers that print the result as "ms" stay correct.
  (System.nanoTime() - startNanos) / 1000000L
}
// Benchmark driver: builds two identical Hive tables, then times a LEFT SEMI JOIN between them.

// Local-mode Spark context; the app name is the enclosing class's simple name with any
// trailing "$" stripped.
val sc = new SparkContext(
new SparkConf()
.setMaster("local")
.setAppName(getClass.getSimpleName.stripSuffix("$")))
val hiveContext = new HiveContext(sc)
// Brings the context's members into scope (presumably including the `sql` used below).
import hiveContext._
// Start from a clean slate so reruns do not see stale tables.
sql("drop table if exists left_table")
sql("drop table if exists right_table")
// Left (streamed) side of the join, loaded from a local file.
sql( """create table left_table (key int, value string)
""".stripMargin)
sql( s"""load data local inpath "/data1/kv3.txt" into table left_table""")
// Right side of the join; filled below with a copy of every row of left_table.
sql( """create table right_table (key int, value string)
""".stripMargin)
sql(
"""
|from left_table
|insert overwrite table right_table
|select left_table.key, left_table.value
""".stripMargin)
// The query under test: an equi-join on `key`, keeping only columns from the left table.
val leftSimeJoin = sql(
"""select a.key from left_table a
|left semi join right_table b on a.key = b.key""".stripMargin)
// Only the count() call is timed; `benchmark` returns the elapsed milliseconds.
val leftSemiJoinDuration = benchmark(leftSimeJoin.count())
println(s"left semi join : $leftSemiJoinDuration ms ")
</code></pre>
Author: wangxiaojing <u9jing@gmail.com>
Closes #3442 from wangxiaojing/SPARK-4570 and squashes the following commits:
a4a43c9 [wangxiaojing] rebase
f103983 [wangxiaojing] change style
fbe4887 [wangxiaojing] change style
ff2e618 [wangxiaojing] add testsuite
1a8da2a [wangxiaojing] add BroadcastLeftSemiJoinHash
Diffstat (limited to 'sql')
4 files changed, 160 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9151da69ed..ce878c137e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -33,6 +33,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) + if sqlContext.autoBroadcastJoinThreshold > 0 && + right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => + val semiJoin = joins.BroadcastLeftSemiJoinHash( + leftKeys, rightKeys, planLater(left), planLater(right)) + condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => val semiJoin = joins.LeftSemiJoinHash( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala new file mode 100644 index 0000000000..2ab064fd01 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.{Expression, Row} +import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Build the right table's join keys into a HashSet, and iteratively go through the left + * table, to find the if join keys are in the Hash set. + */ +@DeveloperApi +case class BroadcastLeftSemiJoinHash( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashJoin { + + override val buildSide = BuildRight + + override def output = left.output + + override def execute() = { + val buildIter= buildPlan.execute().map(_.copy()).collect().toIterator + val hashSet = new java.util.HashSet[Row]() + var currentRow: Row = null + + // Create a Hash set of buildKeys + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + hashSet.add(rowKey) + } + } + } + + val broadcastedRelation = sparkContext.broadcast(hashSet) + + streamedPlan.execute().mapPartitions { streamIter => + val joinKeys = streamSideKeyGenerator() + streamIter.filter(current => { + !joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue) + }) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 0378fd7e36..1a4232dab8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -48,6 +48,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j + case j: BroadcastLeftSemiJoinHash => j } assert(operators.size === 1) @@ -382,4 +383,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { """.stripMargin), (null, 10) :: Nil) } + + test("broadcasted left semi join operator selection") { + clearCache() + sql("CACHE TABLE testData") + val tmp = autoBroadcastJoinThreshold + + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } + + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } + + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) + sql("UNCACHE TABLE testData") + } + + test("left semi join") { + val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + checkAnswer(rdd, + (1, 1) :: + (1, 2) :: + (2, 1) :: + (2, 2) :: + (3, 1) :: + (3, 2) :: Nil) + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index ff4071d8e2..4b6a9308b9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import 
org.apache.spark.sql.{SQLConf, QueryTest} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.execution._ @@ -193,4 +193,52 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { ) } + test("auto converts to broadcast left semi join, by size estimate of a relation") { + val leftSemiJoinQuery = + """SELECT * FROM src a + |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin + val answer = (86, "val_86") :: Nil + + var rdd = sql(leftSemiJoinQuery) + + // Assert src has a size smaller than the threshold. + val sizes = rdd.queryExecution.analyzed.collect { + case r if implicitly[ClassTag[MetastoreRelation]].runtimeClass + .isAssignableFrom(r.getClass) => + r.statistics.sizeInBytes + } + assert(sizes.size === 2 && sizes(1) <= autoBroadcastJoinThreshold + && sizes(0) <= autoBroadcastJoinThreshold, + s"query should contain two relations, each of which has size smaller than autoConvertSize") + + // Using `sparkPlan` because for relevant patterns in HashJoin to be + // matched, other strategies need to be applied. 
+ var bhj = rdd.queryExecution.sparkPlan.collect { + case j: BroadcastLeftSemiJoinHash => j + } + assert(bhj.size === 1, + s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") + + checkAnswer(rdd, answer) // check correctness of output + + TestHive.settings.synchronized { + val tmp = autoBroadcastJoinThreshold + + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + rdd = sql(leftSemiJoinQuery) + bhj = rdd.queryExecution.sparkPlan.collect { + case j: BroadcastLeftSemiJoinHash => j + } + assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") + + val shj = rdd.queryExecution.sparkPlan.collect { + case j: LeftSemiJoinHash => j + } + assert(shj.size === 1, + "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off") + + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp") + } + + } } |