aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorwangxiaojing <u9jing@gmail.com>2014-12-30 13:54:12 -0800
committerMichael Armbrust <michael@databricks.com>2014-12-30 13:54:12 -0800
commit07fa1910d9c4092d670381c447403105f01c584e (patch)
tree6552fd64d119b5333145b49200a598676145c534 /sql
parent8f29b7cafc2b6e802e4eb21f681d6369da2f30fa (diff)
downloadspark-07fa1910d9c4092d670381c447403105f01c584e.tar.gz
spark-07fa1910d9c4092d670381c447403105f01c584e.tar.bz2
spark-07fa1910d9c4092d670381c447403105f01c584e.zip
[SPARK-4570][SQL]add BroadcastLeftSemiJoinHash
JIRA issue: [SPARK-4570](https://issues.apache.org/jira/browse/SPARK-4570) This patch creates a `BroadcastLeftSemiJoinHash` operator to implement the broadcast join for `left semijoin`. In a left semijoin, if the size of the data from the right side does not exceed the user-settable threshold `AUTO_BROADCASTJOIN_THRESHOLD`, the planner marks it as the `broadcast` relation and marks the other relation as the stream side. The broadcast table is sent to all of the executors involved in the join as an `org.apache.spark.broadcast.Broadcast` object, and the planner uses `joins.BroadcastLeftSemiJoinHash`; otherwise it uses `joins.LeftSemiJoinHash`. The benchmark suggests that this makes the optimized version more than 4x faster for `left semijoin`: <pre><code> Original: left semi join : 9288 ms Optimized: left semi join : 1963 ms </code></pre> The micro benchmark loads `data1/kv3.txt` into a normal Hive table. Benchmark code: <pre><code> def benchmark(f: => Unit) = { val begin = System.currentTimeMillis() f val end = System.currentTimeMillis() end - begin } val sc = new SparkContext( new SparkConf() .setMaster("local") .setAppName(getClass.getSimpleName.stripSuffix("$"))) val hiveContext = new HiveContext(sc) import hiveContext._ sql("drop table if exists left_table") sql("drop table if exists right_table") sql( """create table left_table (key int, value string) """.stripMargin) sql( s"""load data local inpath "/data1/kv3.txt" into table left_table""") sql( """create table right_table (key int, value string) """.stripMargin) sql( """ |from left_table |insert overwrite table right_table |select left_table.key, left_table.value """.stripMargin) val leftSimeJoin = sql( """select a.key from left_table a |left semi join right_table b on a.key = b.key""".stripMargin) val leftSemiJoinDuration = benchmark(leftSimeJoin.count()) println(s"left semi join : $leftSemiJoinDuration ms ") </code></pre> Author: wangxiaojing <u9jing@gmail.com> Closes #3442 from wangxiaojing/SPARK-4570 and squashes the following 
commits: a4a43c9 [wangxiaojing] rebase f103983 [wangxiaojing] change style fbe4887 [wangxiaojing] change style ff2e618 [wangxiaojing] add testsuite 1a8da2a [wangxiaojing] add BroadcastLeftSemiJoinHash
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala67
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala38
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala50
4 files changed, 160 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 9151da69ed..ce878c137e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -33,6 +33,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object LeftSemiJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
+ if sqlContext.autoBroadcastJoinThreshold > 0 &&
+ right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
+ val semiJoin = joins.BroadcastLeftSemiJoinHash(
+ leftKeys, rightKeys, planLater(left), planLater(right))
+ condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
// Find left semi joins where at least some predicates can be evaluated by matching join keys
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
val semiJoin = joins.LeftSemiJoinHash(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
new file mode 100644
index 0000000000..2ab064fd01
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.{Expression, Row}
+import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Broadcast variant of the left semi hash join: the distinct, non-null join keys of the
+ * right (build) side are gathered into a hash set that is broadcast, and a left (streamed)
+ * row is emitted only if its join key contains no null and is present in that set.
+ */
+@DeveloperApi
+case class BroadcastLeftSemiJoinHash(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode with HashJoin {
+
+  override val buildSide = BuildRight
+
+  // A semi join only filters left rows, so the output schema is the left child's.
+  override def output = left.output
+
+  override def execute() = {
+    // Rows are copied before being collected, then reduced to their join keys below.
+    val buildRows = buildPlan.execute().map(_.copy()).collect()
+    val hashSet = new java.util.HashSet[Row]()
+
+    // Keys containing a null are left out, so they can never produce a match.
+    // HashSet.add is a no-op for a key that is already present, so no contains() check.
+    buildRows.foreach { row =>
+      val rowKey = buildSideKeyGenerator(row)
+      if (!rowKey.anyNull) {
+        hashSet.add(rowKey)
+      }
+    }
+
+    val broadcastedRelation = sparkContext.broadcast(hashSet)
+
+    streamedPlan.execute().mapPartitions { streamIter =>
+      val joinKeys = streamSideKeyGenerator()
+      val rightKeySet = broadcastedRelation.value // looked up once per partition
+      streamIter.filter { current =>
+        !joinKeys(current).anyNull && rightKeySet.contains(joinKeys.currentValue)
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 0378fd7e36..1a4232dab8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -48,6 +48,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
case j: LeftSemiJoinBNL => j
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
+ case j: BroadcastLeftSemiJoinHash => j
}
assert(operators.size === 1)
@@ -382,4 +383,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
""".stripMargin),
(null, 10) :: Nil)
}
+
+  test("broadcasted left semi join operator selection") {
+    val query = "SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a"
+
+    clearCache()
+    sql("CACHE TABLE testData")
+    val tmp = autoBroadcastJoinThreshold
+
+    // Restore the threshold and uncache the table even when an assertion fails, so that
+    // a failure here cannot leak configuration or cache state into later tests.
+    try {
+      // The right side's size estimate is expected to be below this huge threshold, so
+      // the planner should pick the broadcast variant.
+      sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000")
+      assertJoin(query, classOf[BroadcastLeftSemiJoinHash])
+
+      // The broadcast strategy requires a threshold > 0, so -1 switches it off and the
+      // planner should fall back to LeftSemiJoinHash.
+      sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
+      assertJoin(query, classOf[LeftSemiJoinHash])
+    } finally {
+      setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString)
+      sql("UNCACHE TABLE testData")
+    }
+  }
+
+  test("left semi join") {
+    // Expected (a, b) rows of testData2 whose `a` matches some `key` in testData.
+    val expectedRows = Seq(
+      (1, 1), (1, 2),
+      (2, 1), (2, 2),
+      (3, 1), (3, 2))
+
+    val semiJoined = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
+
+    checkAnswer(semiJoined, expectedRows)
+  }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index ff4071d8e2..4b6a9308b9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterAll
import scala.reflect.ClassTag
import org.apache.spark.sql.{SQLConf, QueryTest}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
+import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.execution._
@@ -193,4 +193,52 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
)
}
+  test("auto converts to broadcast left semi join, by size estimate of a relation") {
+    val leftSemiJoinQuery =
+      """SELECT * FROM src a
+        |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin
+    val answer = (86, "val_86") :: Nil
+
+    val rdd = sql(leftSemiJoinQuery)
+
+    // Assert src has a size smaller than the threshold.
+    val sizes = rdd.queryExecution.analyzed.collect {
+      case r: MetastoreRelation => r.statistics.sizeInBytes
+    }
+    assert(sizes.size === 2 && sizes(1) <= autoBroadcastJoinThreshold
+      && sizes(0) <= autoBroadcastJoinThreshold,
+      "query should contain two relations, each of which has size smaller than autoConvertSize")
+
+    // Using `sparkPlan` because for relevant patterns in HashJoin to be
+    // matched, other strategies need to be applied.
+    val bhj = rdd.queryExecution.sparkPlan.collect {
+      case j: BroadcastLeftSemiJoinHash => j
+    }
+    assert(bhj.size === 1,
+      s"actual query plans do not contain broadcast join: ${rdd.queryExecution}")
+
+    checkAnswer(rdd, answer) // check correctness of output
+
+    TestHive.settings.synchronized {
+      val tmp = autoBroadcastJoinThreshold
+
+      // Restore the threshold even if an assertion below fails, so the disabled setting
+      // cannot leak into other tests sharing TestHive.
+      try {
+        sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
+        // Re-plan the same query with broadcasting disabled.
+        val plan = sql(leftSemiJoinQuery).queryExecution.sparkPlan
+
+        val broadcastJoins = plan.collect { case j: BroadcastLeftSemiJoinHash => j }
+        assert(broadcastJoins.isEmpty,
+          "BroadcastLeftSemiJoinHash still planned even though it is switched off")
+
+        val shj = plan.collect { case j: LeftSemiJoinHash => j }
+        assert(shj.size === 1,
+          "LeftSemiJoinHash should be planned when BroadcastLeftSemiJoinHash is turned off")
+      } finally {
+        sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp")
+      }
+    }
+  }
}