author     Davies Liu <davies@databricks.com>    2016-03-15 22:17:04 -0700
committer  Davies Liu <davies.liu@gmail.com>    2016-03-15 22:17:04 -0700
commit     421f6c20e85b32f6462d37dad6a62dec2d46ed88 (patch)
tree       de14ba806b61c1cb0582bb1058ca1d301f25f9b2 /sql/core
parent     52b6a899be2711568d86ab99d1a2b4d1f9fef286 (diff)
[SPARK-13917] [SQL] generate broadcast semi join
## What changes were proposed in this pull request?

This PR brings codegen support for broadcast left-semi join.

## How was this patch tested?

Existing tests. Added a benchmark; the results show a 7X speedup.

Author: Davies Liu <davies@databricks.com>

Closes #11742 from davies/gen_semi.
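For context, here is a minimal, hypothetical driver program (modeled on the benchmark added in this patch; the object name, table sizes, and column names are illustrative) showing the kind of query this change affects. The build side is small enough to broadcast, so after this patch the planner emits a codegen-capable BroadcastHashJoin with joinType LeftSemi in place of the interpreted BroadcastLeftSemiJoinHash.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.functions.col

    object BroadcastSemiJoinExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[1]").setAppName("semi-join-example")
        val sc = SparkContext.getOrCreate(conf)
        val sqlContext = SQLContext.getOrCreate(sc)

        // Build (right) side: small enough to stay under
        // spark.sql.autoBroadcastJoinThreshold, so it gets broadcast.
        val dim = sqlContext.range(1 << 10).withColumnRenamed("id", "k")
        val large = sqlContext.range(1 << 20)

        // "leftsemi" emits each left row at most once, when its key has a match.
        val semi = large.join(dim, col("id") === col("k"), "leftsemi")
        println(semi.count()) // 1024: only ids 0..1023 have a matching k
      }
    }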
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala              |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala      | 81
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala | 57
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala               | 23
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala           | 61
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala       |  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala                              |  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala   | 14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala     |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala          |  5
10 files changed, 122 insertions, 137 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7fc6a8267f..121b6d9e97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -65,8 +65,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ExtractEquiJoinKeys(
LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
- joins.BroadcastLeftSemiJoinHash(
- leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
+ Seq(joins.BroadcastHashJoin(
+ leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
// Find left semi joins where at least some predicates can be evaluated by matching join keys
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
joins.LeftSemiJoinHash(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 4c8f8080a9..f84ed41f1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -92,6 +92,9 @@ case class BroadcastHashJoin(
rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
}
+ case LeftSemi =>
+ hashSemiJoin(streamedIter, hashTable, numOutputRows)
+
case x =>
throw new IllegalArgumentException(
s"BroadcastHashJoin should not take $x as the JoinType")
@@ -108,11 +111,13 @@ case class BroadcastHashJoin(
}
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
- if (joinType == Inner) {
- codegenInner(ctx, input)
- } else {
- // LeftOuter and RightOuter
- codegenOuter(ctx, input)
+ joinType match {
+ case Inner => codegenInner(ctx, input)
+ case LeftOuter | RightOuter => codegenOuter(ctx, input)
+ case LeftSemi => codegenSemi(ctx, input)
+ case x =>
+ throw new IllegalArgumentException(
+ s"BroadcastHashJoin should not take $x as the JoinType")
}
}
@@ -322,4 +327,68 @@ case class BroadcastHashJoin(
""".stripMargin
}
}
+
+ /**
+ * Generates the code for left semi join.
+ */
+ private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from the build side that are used by the condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
+ s"""
+ |$eval
+ |${ev.code}
+ |if (${ev.isNull} || !${ev.value}) continue;
+ """.stripMargin
+ } else {
+ ""
+ }
+
+ if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched == null) continue;
+ |$checkCondition
+ |$numOutput.add(1);
+ |${consume(ctx, input)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+ val i = ctx.freshName("i")
+ val size = ctx.freshName("size")
+ val found = ctx.freshName("found")
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+ |if ($matches == null) continue;
+ |int $size = $matches.size();
+ |boolean $found = false;
+ |for (int $i = 0; $i < $size; $i++) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+ | $checkCondition
+ | $found = true;
+ | break;
+ |}
+ |if (!$found) continue;
+ |$numOutput.add(1);
+ |${consume(ctx, input)}
+ """.stripMargin
+ }
+ }
}
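To make the two generated paths above easier to follow, here is a plain-Scala sketch (illustrative names only, no Spark types) of the logic the generated Java implements: against a UniqueHashedRelation a single lookup plus a null check suffices, while the general case scans the matched buffer until the optional condition holds once. `exists` below plays the role of the generated for-loop with its `$found` flag and `break`, so each stream row is emitted at most once either way.

    object CodegenSemiSketch {
      // Unique-key path: at most one candidate row per key.
      def semiJoinUnique[K, R](stream: Iterator[(K, R)], relation: Map[K, R],
                               cond: (R, R) => Boolean = (_: R, _: R) => true): Iterator[(K, R)] =
        stream.filter { case (key, row) =>
          relation.get(key).exists(m => cond(row, m)) // missing key == continue
        }

      // General path: scan the matched buffer, stopping at the first passing row.
      def semiJoinMulti[K, R](stream: Iterator[(K, R)], relation: Map[K, Seq[R]],
                              cond: (R, R) => Boolean = (_: R, _: R) => true): Iterator[(K, R)] =
        stream.filter { case (key, row) =>
          relation.getOrElse(key, Nil).exists(m => cond(row, m))
        }
    }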
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
deleted file mode 100644
index d3bcfad7c3..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.TaskContext
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-/**
- * Build the right table's join keys into a HashedRelation, and iteratively go through the left
- * table, to find if the join keys are in the HashedRelation.
- */
-case class BroadcastLeftSemiJoinHash(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- left: SparkPlan,
- right: SparkPlan,
- condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
-
- override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
- override def requiredChildDistribution: Seq[Distribution] = {
- val mode = HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
- UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
- }
-
- protected override def doExecute(): RDD[InternalRow] = {
- val numOutputRows = longMetric("numOutputRows")
-
- val broadcastedRelation = right.executeBroadcast[HashedRelation]()
- left.execute().mapPartitionsInternal { streamIter =>
- val hashedRelation = broadcastedRelation.value
- TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
- hashSemiJoin(streamIter, hashedRelation, numOutputRows)
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 2fe9c06cc9..5f42d07273 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -46,8 +46,8 @@ trait HashJoin {
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
- case FullOuter =>
- left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+ case LeftSemi =>
+ left.output
case x =>
throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType")
}
@@ -104,7 +104,7 @@ trait HashJoin {
keyExpr :: Nil
}
- protected val canJoinKeyFitWithinLong: Boolean = {
+ protected lazy val canJoinKeyFitWithinLong: Boolean = {
val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
val key = rewriteKeyExpr(buildKeys)
sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
@@ -258,4 +258,21 @@ trait HashJoin {
}
ret.iterator
}
+
+ protected def hashSemiJoin(
+ streamIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation,
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+ val joinKeys = streamSideKeyGenerator
+ val joinedRow = new JoinedRow
+ streamIter.filter { current =>
+ val key = joinKeys(current)
+ lazy val rowBuffer = hashedRelation.get(key)
+ val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
+ (row: InternalRow) => boundCondition(joinedRow(current, row))
+ })
+ if (r) numOutputRows += 1
+ r
+ }
+ }
}
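The interpreted path above relies on lazy evaluation for its short-circuiting; the following self-contained sketch (a hypothetical Key type, not Spark's UnsafeRow machinery) shows the same order of checks: the hash lookup behind `rowBuffer` never runs when the key contains a null, and when no extra condition is given any non-empty bucket lets the stream row through.

    object HashSemiJoinSketch {
      final case class Key(value: Long, anyNull: Boolean)

      def hashSemiJoin(streamIter: Iterator[Key],
                       hashed: Map[Long, Seq[Long]],
                       condition: Option[(Key, Long) => Boolean]): Iterator[Key] =
        streamIter.filter { current =>
          // Lazy: evaluated only if `!current.anyNull` holds.
          lazy val rowBuffer = hashed.get(current.value)
          !current.anyNull && rowBuffer.isDefined &&
            (condition.isEmpty || rowBuffer.get.exists(row => condition.get(current, row)))
        }

      def main(args: Array[String]): Unit = {
        val hashed = Map(1L -> Seq(10L), 2L -> Seq(20L))
        val keys = Iterator(Key(1, anyNull = false), Key(2, anyNull = true), Key(3, anyNull = false))
        println(hashSemiJoin(keys, hashed, None).map(_.value).toList) // List(1)
      }
    }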
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
deleted file mode 100644
index 813ec02425..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.metric.LongSQLMetric
-
-
-trait HashSemiJoin {
- self: SparkPlan =>
- val leftKeys: Seq[Expression]
- val rightKeys: Seq[Expression]
- val left: SparkPlan
- val right: SparkPlan
- val condition: Option[Expression]
-
- override def output: Seq[Attribute] = left.output
-
- protected def leftKeyGenerator: Projection =
- UnsafeProjection.create(leftKeys, left.output)
-
- protected def rightKeyGenerator: Projection =
- UnsafeProjection.create(rightKeys, right.output)
-
- @transient private lazy val boundCondition =
- newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
-
- protected def hashSemiJoin(
- streamIter: Iterator[InternalRow],
- hashedRelation: HashedRelation,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val joinKeys = leftKeyGenerator
- val joinedRow = new JoinedRow
- streamIter.filter { current =>
- val key = joinKeys(current)
- lazy val rowBuffer = hashedRelation.get(key)
- val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
- (row: InternalRow) => boundCondition(joinedRow(current, row))
- })
- if (r) numOutputRows += 1
- r
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 14389e45ba..fa549b4d51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -33,7 +34,10 @@ case class LeftSemiJoinHash(
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan,
- condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
+ condition: Option[Expression]) extends BinaryNode with HashJoin {
+
+ override val joinType = LeftSemi
+ override val buildSide = BuildRight
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -47,7 +51,7 @@ case class LeftSemiJoinHash(
val numOutputRows = longMetric("numOutputRows")
right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
- val hashRelation = HashedRelation(buildIter.map(_.copy()), rightKeyGenerator)
+ val hashRelation = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
hashSemiJoin(streamIter, hashRelation, numOutputRows)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 580e8d815a..41919910a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -49,7 +49,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
case j: BroadcastHashJoin => j
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
- case j: BroadcastLeftSemiJoinHash => j
+ case j: BroadcastHashJoin => j
case j: SortMergeJoin => j
}
@@ -427,7 +427,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
- classOf[BroadcastLeftSemiJoinHash])
+ classOf[BroadcastHashJoin])
).foreach {
case (query, joinClass) => assertJoin(query, joinClass)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 9f33e4ab62..cb672643f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -38,7 +38,7 @@ import org.apache.spark.util.Benchmark
class BenchmarkWholeStageCodegen extends SparkFunSuite {
lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
.set("spark.sql.shuffle.partitions", "1")
- .set("spark.sql.autoBroadcastJoinThreshold", "0")
+ .set("spark.sql.autoBroadcastJoinThreshold", "1")
lazy val sc = SparkContext.getOrCreate(conf)
lazy val sqlContext = SQLContext.getOrCreate(sc)
@@ -200,6 +200,18 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
outer join w long codegen=false 15280 / 16497 6.9 145.7 1.0X
outer join w long codegen=true 769 / 796 136.3 7.3 19.9X
*/
+
+ runBenchmark("semi join w long", N) {
+ sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count()
+ }
+
+ /**
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ semi join w long codegen=false 5804 / 5969 18.1 55.3 1.0X
+ semi join w long codegen=true 814 / 934 128.8 7.8 7.1X
+ */
}
ignore("sort merge join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 6d5b777733..babe7ef70f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -79,7 +79,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
}
test("unsafe broadcast left semi join updates peak execution memory") {
- testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi")
+ testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast left semi join", "leftsemi")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index d8c9564f1e..5eb6a74523 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -84,11 +84,12 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
}
}
- test(s"$testName using BroadcastLeftSemiJoinHash") {
+ test(s"$testName using BroadcastHashJoin") {
extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
+ BroadcastHashJoin(
+ leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}