author     Davies Liu <davies@databricks.com>  2016-02-18 15:15:06 -0800
committer  Davies Liu <davies.liu@gmail.com>   2016-02-18 15:15:06 -0800
commit     95e1ab223e87fc216f3256d404fe3be50d111a9d (patch)
tree       c4a92105835a85c38565d3121e83d38c3b000da4 /sql/core/src
parent     26f38bb83c423e512955ca25775914dae7e5bbf0 (diff)
download   spark-95e1ab223e87fc216f3256d404fe3be50d111a9d.tar.gz
           spark-95e1ab223e87fc216f3256d404fe3be50d111a9d.tar.bz2
           spark-95e1ab223e87fc216f3256d404fe3be50d111a9d.zip
[SPARK-13237] [SQL] generated broadcast outer join
This PR supports codegen for broadcast outer joins. To reduce code duplication, it merges HashJoin and HashOuterJoin (and likewise BroadcastHashJoin and BroadcastHashOuterJoin) into a single implementation.

Author: Davies Liu <davies@databricks.com>

Closes #11130 from davies/gen_out.
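To make the shape of the change easier to follow before reading the diff: the planner now emits a single BroadcastHashJoin operator for inner, left outer, and right outer equi-joins, carrying the JoinType and build side explicitly, instead of dispatching outer joins to a separate BroadcastHashOuterJoin. The sketch below is a self-contained illustration using simplified stand-in types, not Spark's actual classes (the real operator also takes join key expressions, an optional condition, and child SparkPlans); it only mirrors the pattern visible in the SparkStrategies changes further down.

object BroadcastJoinPlanningSketch extends App {
  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftOuter extends JoinType
  case object RightOuter extends JoinType

  sealed trait BuildSide
  case object BuildLeft extends BuildSide
  case object BuildRight extends BuildSide

  // Simplified stand-in for the physical operator: one case class now covers all three join types.
  case class BroadcastHashJoin(joinType: JoinType, buildSide: BuildSide)

  // All three cases map onto the same operator; only the JoinType and build side differ.
  def planBroadcastJoin(joinType: JoinType, canBroadcastRight: Boolean): Seq[BroadcastHashJoin] =
    joinType match {
      case Inner if canBroadcastRight => Seq(BroadcastHashJoin(Inner, BuildRight))
      case Inner                      => Seq(BroadcastHashJoin(Inner, BuildLeft))
      case LeftOuter                  => Seq(BroadcastHashJoin(LeftOuter, BuildRight)) // build (broadcast) the right side
      case RightOuter                 => Seq(BroadcastHashJoin(RightOuter, BuildLeft)) // build (broadcast) the left side
    }

  println(planBroadcastJoin(LeftOuter, canBroadcastRight = true))
  // List(BroadcastHashJoin(LeftOuter,BuildRight)), matching the LeftOuter case in the strategy below.
}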
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala | 253
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala | 121
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala | 111
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala | 153
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala | 131
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala | 9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala | 8
12 files changed, 448 insertions, 371 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 042c99db4d..382654afac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -108,12 +108,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// --- Inner joins --------------------------------------------------------------------------
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
- joins.BroadcastHashJoin(
- leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
+ Seq(joins.BroadcastHashJoin(
+ leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
- joins.BroadcastHashJoin(
- leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
+ Seq(joins.BroadcastHashJoin(
+ leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
@@ -124,13 +124,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ExtractEquiJoinKeys(
LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
- joins.BroadcastHashOuterJoin(
- leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
+ Seq(joins.BroadcastHashJoin(
+ leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(
RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
- joins.BroadcastHashOuterJoin(
- leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
+ Seq(joins.BroadcastHashJoin(
+ leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index f35efb5b24..8626f54eb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, LongSQLMetricValue, SQLMetric}
+import org.apache.spark.sql.execution.metric.LongSQLMetricValue
/**
* An interface for those physical operators that support codegen.
@@ -38,7 +38,7 @@ trait CodegenSupport extends SparkPlan {
/** Prefix used in the current operator's variable names. */
private def variablePrefix: String = this match {
case _: TungstenAggregate => "agg"
- case _: BroadcastHashJoin => "bhj"
+ case _: BroadcastHashJoin => "join"
case _ => nodeName.toLowerCase
}
@@ -391,9 +391,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
// The build side can't be compiled together
- case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) =>
+ case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) =>
b.copy(left = apply(left))
- case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
+ case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
b.copy(right = apply(right))
case p if !supportCodegen(p) =>
val input = apply(p) // collapse them recursively
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 985e74011d..a64da22580 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -24,8 +24,9 @@ import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -41,6 +42,7 @@ import org.apache.spark.util.collection.CompactBuffer
case class BroadcastHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
+ joinType: JoinType,
buildSide: BuildSide,
condition: Option[Expression],
left: SparkPlan,
@@ -105,75 +107,144 @@ case class BroadcastHashJoin(
val broadcastRelation = Await.result(broadcastFuture, timeout)
streamedPlan.execute().mapPartitions { streamedIter =>
- val hashedRelation = broadcastRelation.value
- TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
- hashJoin(streamedIter, hashedRelation, numOutputRows)
+ val joinedRow = new JoinedRow()
+ val hashTable = broadcastRelation.value
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
+ val keyGenerator = streamSideKeyGenerator
+ val resultProj = createResultProjection
+
+ joinType match {
+ case Inner =>
+ hashJoin(streamedIter, hashTable, numOutputRows)
+
+ case LeftOuter =>
+ streamedIter.flatMap { currentRow =>
+ val rowKey = keyGenerator(currentRow)
+ joinedRow.withLeft(currentRow)
+ leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
+ }
+
+ case RightOuter =>
+ streamedIter.flatMap { currentRow =>
+ val rowKey = keyGenerator(currentRow)
+ joinedRow.withRight(currentRow)
+ rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
+ }
+
+ case x =>
+ throw new IllegalArgumentException(
+ s"BroadcastHashJoin should not take $x as the JoinType")
+ }
}
}
- private var broadcastRelation: Broadcast[HashedRelation] = _
- // the term for hash relation
- private var relationTerm: String = _
-
override def upstream(): RDD[InternalRow] = {
streamedPlan.asInstanceOf[CodegenSupport].upstream()
}
override def doProduce(ctx: CodegenContext): String = {
+ streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ if (joinType == Inner) {
+ codegenInner(ctx, input)
+ } else {
+ // LeftOuter and RightOuter
+ codegenOuter(ctx, input)
+ }
+ }
+
+ /**
+ * Returns a tuple of Broadcast of HashedRelation and the variable name for it.
+ */
+ private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = {
// create a name for HashedRelation
- broadcastRelation = Await.result(broadcastFuture, timeout)
+ val broadcastRelation = Await.result(broadcastFuture, timeout)
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
- relationTerm = ctx.freshName("relation")
+ val relationTerm = ctx.freshName("relation")
val clsName = broadcastRelation.value.getClass.getName
ctx.addMutableState(clsName, relationTerm,
s"""
| $relationTerm = ($clsName) $broadcast.value();
| incPeakExecutionMemory($relationTerm.getMemorySize());
""".stripMargin)
-
- s"""
- | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
- """.stripMargin
+ (broadcastRelation, relationTerm)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
- // generate the key as UnsafeRow or Long
+ /**
+ * Returns the code for generating join key for stream side, and expression of whether the key
+ * has any null in it or not.
+ */
+ private def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
ctx.currentVars = input
- val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) {
+ if (canJoinKeyFitWithinLong) {
+ // generate the join key as Long
val expr = rewriteKeyExpr(streamedKeys).head
val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
(ev, ev.isNull)
} else {
+ // generate the join key as UnsafeRow
val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
(ev, s"${ev.value}.anyNull()")
}
+ }
- // find the matches from HashedRelation
- val matched = ctx.freshName("matched")
-
- // create variables for output
+ /**
+ * Generates the code for variable of build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = {
ctx.currentVars = null
ctx.INPUT_ROW = matched
- val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
- BoundReference(i, a.dataType, a.nullable).gen(ctx)
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx)
+ if (joinType == Inner) {
+ ev
+ } else {
+ // the variables are needed even there is no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val code = s"""
+ |boolean $isNull = true;
+ |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, isNull, value)
+ }
}
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
val resultVars = buildSide match {
- case BuildLeft => buildColumns ++ input
- case BuildRight => input ++ buildColumns
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
}
-
val numOutput = metricTerm(ctx, "numOutputRows")
+
val outputCode = if (condition.isDefined) {
// filter the output via condition
ctx.currentVars = resultVars
val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
s"""
- | ${ev.code}
- | if (!${ev.isNull} && ${ev.value}) {
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- | }
+ |${ev.code}
+ |if (!${ev.isNull} && ${ev.value}) {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ |}
""".stripMargin
} else {
s"""
@@ -184,36 +255,110 @@ case class BroadcastHashJoin(
if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
s"""
- | // generate join key
- | ${keyVal.code}
- | // find matches from HashedRelation
- | UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value});
- | if ($matched != null) {
- | ${buildColumns.map(_.code).mkString("\n")}
- | $outputCode
- | }
- """.stripMargin
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | ${buildVars.map(_.code).mkString("\n")}
+ | $outputCode
+ |}
+ """.stripMargin
+
+ } else {
+ val matches = ctx.freshName("matches")
+ val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+ val i = ctx.freshName("i")
+ val size = ctx.freshName("size")
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | int $size = $matches.size();
+ | for (int $i = 0; $i < $size; $i++) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+ | ${buildVars.map(_.code).mkString("\n")}
+ | $outputCode
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+
+ /**
+ * Generates the code for left or right outer join.
+ */
+ private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ // filter the output via condition
+ val conditionPassed = ctx.freshName("conditionPassed")
+ val checkCondition = if (condition.isDefined) {
+ ctx.currentVars = resultVars
+ val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+ s"""
+ |boolean $conditionPassed = true;
+ |if ($matched != null) {
+ | ${ev.code}
+ | $conditionPassed = !${ev.isNull} && ${ev.value};
+ |}
+ """.stripMargin
+ } else {
+ s"final boolean $conditionPassed = true;"
+ }
+
+ if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |${buildVars.map(_.code).mkString("\n")}
+ |${checkCondition.trim}
+ |if (!$conditionPassed) {
+ | // reset to null
+ | ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")}
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVars)}
+ """.stripMargin
} else {
val matches = ctx.freshName("matches")
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
val i = ctx.freshName("i")
val size = ctx.freshName("size")
+ val found = ctx.freshName("found")
s"""
- | // generate join key
- | ${keyVal.code}
- | // find matches from HashRelation
- | $bufferType $matches = ${anyNull} ? null :
- | ($bufferType) $relationTerm.get(${keyVal.value});
- | if ($matches != null) {
- | int $size = $matches.size();
- | for (int $i = 0; $i < $size; $i++) {
- | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
- | ${buildColumns.map(_.code).mkString("\n")}
- | $outputCode
- | }
- | }
- """.stripMargin
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+ |int $size = $matches != null ? $matches.size() : 0;
+ |boolean $found = false;
+ |// the last iteration of this loop is to emit an empty row if there is no matched rows.
+ |for (int $i = 0; $i <= $size; $i++) {
+ | UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
+ | ${buildVars.map(_.code).mkString("\n")}
+ | ${checkCondition.trim}
+ | if ($conditionPassed && ($i < $size || !$found)) {
+ | $found = true;
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
deleted file mode 100644
index 5e8c8ca043..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import scala.concurrent._
-import scala.concurrent.duration._
-
-import org.apache.spark.{InternalAccumulator, TaskContext}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-/**
- * Performs a outer hash join for two child relations. When the output RDD of this operator is
- * being constructed, a Spark job is asynchronously started to calculate the values for the
- * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed
- * relation is not shuffled.
- */
-case class BroadcastHashOuterJoin(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- joinType: JoinType,
- condition: Option[Expression],
- left: SparkPlan,
- right: SparkPlan) extends BinaryNode with HashOuterJoin {
-
- override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
- val timeout = {
- val timeoutValue = sqlContext.conf.broadcastTimeout
- if (timeoutValue < 0) {
- Duration.Inf
- } else {
- timeoutValue.seconds
- }
- }
-
- override def requiredChildDistribution: Seq[Distribution] =
- UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
-
- override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
-
- // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value
- // for the same query.
- @transient
- private lazy val broadcastFuture = {
- // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
- val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
- Future {
- // This will run in another thread. Set the execution id so that we can connect these jobs
- // with the correct execution.
- SQLExecution.withExecutionId(sparkContext, executionId) {
- // Note that we use .execute().collect() because we don't want to convert data to Scala
- // types
- val input: Array[InternalRow] = buildPlan.execute().map { row =>
- row.copy()
- }.collect()
- val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
- sparkContext.broadcast(hashed)
- }
- }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
- }
-
- protected override def doPrepare(): Unit = {
- broadcastFuture
- }
-
- override def doExecute(): RDD[InternalRow] = {
- val numOutputRows = longMetric("numOutputRows")
-
- val broadcastRelation = Await.result(broadcastFuture, timeout)
-
- streamedPlan.execute().mapPartitions { streamedIter =>
- val joinedRow = new JoinedRow()
- val hashTable = broadcastRelation.value
- val keyGenerator = streamedKeyGenerator
- TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
-
- val resultProj = resultProjection
- joinType match {
- case LeftOuter =>
- streamedIter.flatMap(currentRow => {
- val rowKey = keyGenerator(currentRow)
- joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
- })
-
- case RightOuter =>
- streamedIter.flatMap(currentRow => {
- val rowKey = keyGenerator(currentRow)
- joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
- })
-
- case x =>
- throw new IllegalArgumentException(
- s"BroadcastHashOuterJoin should not take $x as the JoinType")
- }
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 332a748d3b..2fe9c06cc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -21,20 +21,38 @@ import java.util.NoSuchElementException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.sql.types.{IntegralType, LongType}
+import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType}
+import org.apache.spark.util.collection.CompactBuffer
trait HashJoin {
self: SparkPlan =>
val leftKeys: Seq[Expression]
val rightKeys: Seq[Expression]
+ val joinType: JoinType
val buildSide: BuildSide
val condition: Option[Expression]
val left: SparkPlan
val right: SparkPlan
+ override def output: Seq[Attribute] = {
+ joinType match {
+ case Inner =>
+ left.output ++ right.output
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case FullOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+ case x =>
+ throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType")
+ }
+ }
+
protected lazy val (buildPlan, streamedPlan) = buildSide match {
case BuildLeft => (left, right)
case BuildRight => (right, left)
@@ -45,8 +63,6 @@ trait HashJoin {
case BuildRight => (rightKeys, leftKeys)
}
- override def output: Seq[Attribute] = left.output ++ right.output
-
/**
* Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long.
*
@@ -67,8 +83,17 @@ trait HashJoin {
width = dt.defaultSize
} else {
val bits = dt.defaultSize * 8
+ // hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same
+ // value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys
+ // with two same ints have hash code 0, we rotate the bits of second one.
+ val rotated = if (e.dataType == IntegerType) {
+ // (e >>> 15) | (e << 17)
+ BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17)))
+ } else {
+ e
+ }
keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
- BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+ BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1)))
width -= bits
}
// TODO: support BooleanType, DateType and TimestampType
@@ -97,11 +122,13 @@ trait HashJoin {
(r: InternalRow) => true
}
+ protected def createResultProjection: (InternalRow) => InternalRow =
+ UnsafeProjection.create(self.schema)
+
protected def hashJoin(
streamIter: Iterator[InternalRow],
hashedRelation: HashedRelation,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] =
- {
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
new Iterator[InternalRow] {
private[this] var currentStreamedRow: InternalRow = _
private[this] var currentHashMatches: Seq[InternalRow] = _
@@ -109,8 +136,7 @@ trait HashJoin {
// Mutable per row objects.
private[this] val joinRow = new JoinedRow
- private[this] val resultProjection: (InternalRow) => InternalRow =
- UnsafeProjection.create(self.schema)
+ private[this] val resultProjection = createResultProjection
private[this] val joinKeys = streamSideKeyGenerator
@@ -163,4 +189,73 @@ trait HashJoin {
}
}
}
+
+ @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
+
+ @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
+ @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
+
+ protected[this] def leftOuterIterator(
+ key: InternalRow,
+ joinedRow: JoinedRow,
+ rightIter: Iterable[InternalRow],
+ resultProjection: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+ val ret: Iterable[InternalRow] = {
+ if (!key.anyNull) {
+ val temp = if (rightIter != null) {
+ rightIter.collect {
+ case r if boundCondition(joinedRow.withRight(r)) => {
+ numOutputRows += 1
+ resultProjection(joinedRow).copy()
+ }
+ }
+ } else {
+ List.empty
+ }
+ if (temp.isEmpty) {
+ numOutputRows += 1
+ resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+ } else {
+ temp
+ }
+ } else {
+ numOutputRows += 1
+ resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+ }
+ }
+ ret.iterator
+ }
+
+ protected[this] def rightOuterIterator(
+ key: InternalRow,
+ leftIter: Iterable[InternalRow],
+ joinedRow: JoinedRow,
+ resultProjection: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+ val ret: Iterable[InternalRow] = {
+ if (!key.anyNull) {
+ val temp = if (leftIter != null) {
+ leftIter.collect {
+ case l if boundCondition(joinedRow.withLeft(l)) => {
+ numOutputRows += 1
+ resultProjection(joinedRow).copy()
+ }
+ }
+ } else {
+ List.empty
+ }
+ if (temp.isEmpty) {
+ numOutputRows += 1
+ resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
+ } else {
+ temp
+ }
+ } else {
+ numOutputRows += 1
+ resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
+ }
+ }
+ ret.iterator
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
deleted file mode 100644
index 9e614309de..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ /dev/null
@@ -1,153 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.util.collection.CompactBuffer
-
-
-trait HashOuterJoin {
- self: SparkPlan =>
-
- val leftKeys: Seq[Expression]
- val rightKeys: Seq[Expression]
- val joinType: JoinType
- val condition: Option[Expression]
- val left: SparkPlan
- val right: SparkPlan
-
- override def output: Seq[Attribute] = {
- joinType match {
- case LeftOuter =>
- left.output ++ right.output.map(_.withNullability(true))
- case RightOuter =>
- left.output.map(_.withNullability(true)) ++ right.output
- case FullOuter =>
- left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case x =>
- throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
- }
- }
-
- protected[this] lazy val (buildPlan, streamedPlan) = joinType match {
- case RightOuter => (left, right)
- case LeftOuter => (right, left)
- case x =>
- throw new IllegalArgumentException(
- s"HashOuterJoin should not take $x as the JoinType")
- }
-
- protected[this] lazy val (buildKeys, streamedKeys) = joinType match {
- case RightOuter => (leftKeys, rightKeys)
- case LeftOuter => (rightKeys, leftKeys)
- case x =>
- throw new IllegalArgumentException(
- s"HashOuterJoin should not take $x as the JoinType")
- }
-
- protected def buildKeyGenerator: Projection =
- UnsafeProjection.create(buildKeys, buildPlan.output)
-
- protected[this] def streamedKeyGenerator: Projection =
- UnsafeProjection.create(streamedKeys, streamedPlan.output)
-
- protected[this] def resultProjection: InternalRow => InternalRow =
- UnsafeProjection.create(output, output)
-
- @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
- @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
-
- @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
- @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
- @transient private[this] lazy val boundCondition = if (condition.isDefined) {
- newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
- } else {
- (row: InternalRow) => true
- }
-
- // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
- // iterator for performance purpose.
-
- protected[this] def leftOuterIterator(
- key: InternalRow,
- joinedRow: JoinedRow,
- rightIter: Iterable[InternalRow],
- resultProjection: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val ret: Iterable[InternalRow] = {
- if (!key.anyNull) {
- val temp = if (rightIter != null) {
- rightIter.collect {
- case r if boundCondition(joinedRow.withRight(r)) => {
- numOutputRows += 1
- resultProjection(joinedRow).copy()
- }
- }
- } else {
- List.empty
- }
- if (temp.isEmpty) {
- numOutputRows += 1
- resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
- } else {
- temp
- }
- } else {
- numOutputRows += 1
- resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
- }
- }
- ret.iterator
- }
-
- protected[this] def rightOuterIterator(
- key: InternalRow,
- leftIter: Iterable[InternalRow],
- joinedRow: JoinedRow,
- resultProjection: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val ret: Iterable[InternalRow] = {
- if (!key.anyNull) {
- val temp = if (leftIter != null) {
- leftIter.collect {
- case l if boundCondition(joinedRow.withLeft(l)) => {
- numOutputRows += 1
- resultProjection(joinedRow).copy()
- }
- }
- } else {
- List.empty
- }
- if (temp.isEmpty) {
- numOutputRows += 1
- resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
- } else {
- temp
- }
- } else {
- numOutputRows += 1
- resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
- }
- }
- ret.iterator
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 9a3c262e94..92ff7e73fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -46,7 +46,6 @@ class JoinSuite extends QueryTest with SharedSQLContext {
val operators = physical.collect {
case j: LeftSemiJoinHash => j
case j: BroadcastHashJoin => j
- case j: BroadcastHashOuterJoin => j
case j: LeftSemiJoinBNL => j
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
@@ -123,9 +122,9 @@ class JoinSuite extends QueryTest with SharedSQLContext {
("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
classOf[SortMergeOuterJoin]),
("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
- classOf[BroadcastHashOuterJoin]),
+ classOf[BroadcastHashJoin]),
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
- classOf[BroadcastHashOuterJoin])
+ classOf[BroadcastHashJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
sql("UNCACHE TABLE testData")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 4a151179bf..bcac660a35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.util.HashMap
+
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.SQLContext
@@ -124,37 +126,65 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
ignore("broadcast hash join") {
val N = 100 << 20
- val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))
+ val M = 1 << 16
+ val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v"))
runBenchmark("Join w long", N) {
- sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count()
+ sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count()
}
/*
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
- BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w long codegen=false 10174 / 10317 10.0 100.0 1.0X
- Join w long codegen=true 1069 / 1107 98.0 10.2 9.5X
+ Join w long codegen=false 5744 / 5814 18.3 54.8 1.0X
+ Join w long codegen=true 735 / 853 142.7 7.0 7.8X
*/
- val dim2 = broadcast(sqlContext.range(1 << 16)
+ val dim2 = broadcast(sqlContext.range(M)
.selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v"))
runBenchmark("Join w 2 ints", N) {
sqlContext.range(N).join(dim2,
- (col("id") bitwiseAND 60000).cast(IntegerType) === col("k1")
- && (col("id") bitwiseAND 50000).cast(IntegerType) === col("k2")).count()
+ (col("id") bitwiseAND M).cast(IntegerType) === col("k1")
+ && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count()
}
/**
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
- BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w 2 ints codegen=false 11435 / 11530 9.0 111.1 1.0X
- Join w 2 ints codegen=true 1265 / 1424 82.0 12.2 9.0X
+ Join w 2 ints codegen=false 7159 / 7224 14.6 68.3 1.0X
+ Join w 2 ints codegen=true 1135 / 1197 92.4 10.8 6.3X
*/
+ val dim3 = broadcast(sqlContext.range(M)
+ .selectExpr("id as k1", "id as k2", "cast(id as string) as v"))
+
+ runBenchmark("Join w 2 longs", N) {
+ sqlContext.range(N).join(dim3,
+ (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
+ .count()
+ }
+
+ /**
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Join w 2 longs codegen=false 7877 / 8358 13.3 75.1 1.0X
+ Join w 2 longs codegen=true 3877 / 3937 27.0 37.0 2.0X
+ */
+ runBenchmark("outer join w long", N) {
+ sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count()
+ }
+
+ /**
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ outer join w long codegen=false 15280 / 16497 6.9 145.7 1.0X
+ outer join w long codegen=true 769 / 796 136.3 7.3 19.9X
+ */
}
ignore("rube") {
@@ -175,7 +205,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
ignore("hash and BytesToBytesMap") {
- val N = 50 << 20
+ val N = 10 << 20
val benchmark = new Benchmark("BytesToBytesMap", N)
@@ -227,6 +257,80 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
}
+ benchmark.addCase("Java HashMap (Long)") { iter =>
+ var i = 0
+ val keyBytes = new Array[Byte](16)
+ val valueBytes = new Array[Byte](16)
+ val value = new UnsafeRow(1)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ value.setInt(0, 555)
+ val map = new HashMap[Long, UnsafeRow]()
+ while (i < 65536) {
+ value.setInt(0, i)
+ map.put(i.toLong, value)
+ i += 1
+ }
+ var s = 0
+ i = 0
+ while (i < N) {
+ if (map.get(i % 100000) != null) {
+ s += 1
+ }
+ i += 1
+ }
+ }
+
+ benchmark.addCase("Java HashMap (two ints) ") { iter =>
+ var i = 0
+ val valueBytes = new Array[Byte](16)
+ val value = new UnsafeRow(1)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ value.setInt(0, 555)
+ val map = new HashMap[Long, UnsafeRow]()
+ while (i < 65536) {
+ value.setInt(0, i)
+ val key = (i.toLong << 32) + Integer.rotateRight(i, 15)
+ map.put(key, value)
+ i += 1
+ }
+ var s = 0
+ i = 0
+ while (i < N) {
+ val key = ((i & 100000).toLong << 32) + Integer.rotateRight(i & 100000, 15)
+ if (map.get(key) != null) {
+ s += 1
+ }
+ i += 1
+ }
+ }
+
+ benchmark.addCase("Java HashMap (UnsafeRow)") { iter =>
+ var i = 0
+ val keyBytes = new Array[Byte](16)
+ val valueBytes = new Array[Byte](16)
+ val key = new UnsafeRow(1)
+ key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ val value = new UnsafeRow(1)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ value.setInt(0, 555)
+ val map = new HashMap[UnsafeRow, UnsafeRow]()
+ while (i < 65536) {
+ key.setInt(0, i)
+ value.setInt(0, i)
+ map.put(key, value.copy())
+ i += 1
+ }
+ var s = 0
+ i = 0
+ while (i < N) {
+ key.setInt(0, i % 100000)
+ if (map.get(key) != null) {
+ s += 1
+ }
+ i += 1
+ }
+ }
+
Seq("off", "on").foreach { heap =>
benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
val taskMemoryManager = new TaskMemoryManager(
@@ -268,6 +372,9 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
hash 651 / 678 80.0 12.5 1.0X
fast hash 336 / 343 155.9 6.4 1.9X
arrayEqual 417 / 428 125.0 8.0 1.6X
+ Java HashMap (Long) 145 / 168 72.2 13.8 0.8X
+ Java HashMap (two ints) 157 / 164 66.8 15.0 0.8X
+ Java HashMap (UnsafeRow) 538 / 573 19.5 51.3 0.2X
BytesToBytesMap (off Heap) 2594 / 2664 20.2 49.5 0.2X
BytesToBytesMap (on Heap) 2693 / 2989 19.5 51.4 0.2X
*/
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index aee8e84db5..e25b5e0610 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -73,7 +73,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
}
test("unsafe broadcast hash outer join updates peak execution memory") {
- testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer")
+ testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash outer join", "left_outer")
}
test("unsafe broadcast left semi join updates peak execution memory") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 149f34dbd7..e22a810a6b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -88,7 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
- joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
+ joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, boundCondition, leftPlan, rightPlan)
}
def makeSortMergeJoin(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 3d3e9a7b90..f4b01fbad0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -75,11 +75,16 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
}
if (joinType != FullOuter) {
- test(s"$testName using BroadcastHashOuterJoin") {
+ test(s"$testName using BroadcastHashJoin") {
+ val buildSide = joinType match {
+ case LeftOuter => BuildRight
+ case RightOuter => BuildLeft
+ }
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
+ BroadcastHashJoin(
+ leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index f4bc9e501c..46bb699b78 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -209,20 +209,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
)
}
- test("BroadcastHashOuterJoin metrics") {
+ test("BroadcastHashJoin(outer) metrics") {
val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value")
val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value")
// Assume the execution plan is
- // ... -> BroadcastHashOuterJoin(nodeId = 0)
+ // ... -> BroadcastHashJoin(nodeId = 0)
val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer")
testSparkPlanMetrics(df, 2, Map(
- 0L -> ("BroadcastHashOuterJoin", Map(
+ 0L -> ("BroadcastHashJoin", Map(
"number of output rows" -> 5L)))
)
val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer")
testSparkPlanMetrics(df3, 2, Map(
- 0L -> ("BroadcastHashOuterJoin", Map(
+ 0L -> ("BroadcastHashJoin", Map(
"number of output rows" -> 6L)))
)
}