Diffstat (limited to 'sql'):
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala              |  16
 sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala            |   8
 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala      | 253
 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala | 121
 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala               | 111
 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala          | 153
 sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala                              |   5
 sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala   | 131
 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala     |   2
 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala         |   2
 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala         |   9
 sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala       |   8
 12 files changed, 448 insertions(+), 371 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 042c99db4d..382654afac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -108,12 +108,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// --- Inner joins --------------------------------------------------------------------------
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
- joins.BroadcastHashJoin(
- leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
+ Seq(joins.BroadcastHashJoin(
+ leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
- joins.BroadcastHashJoin(
- leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
+ Seq(joins.BroadcastHashJoin(
+ leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
@@ -124,13 +124,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ExtractEquiJoinKeys(
LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
- joins.BroadcastHashOuterJoin(
- leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
+ Seq(joins.BroadcastHashJoin(
+ leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(
RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
- joins.BroadcastHashOuterJoin(
- leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
+ Seq(joins.BroadcastHashJoin(
+ leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index f35efb5b24..8626f54eb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, LongSQLMetricValue, SQLMetric}
+import org.apache.spark.sql.execution.metric.LongSQLMetricValue
/**
* An interface for those physical operators that support codegen.
@@ -38,7 +38,7 @@ trait CodegenSupport extends SparkPlan {
/** Prefix used in the current operator's variable names. */
private def variablePrefix: String = this match {
case _: TungstenAggregate => "agg"
- case _: BroadcastHashJoin => "bhj"
+ case _: BroadcastHashJoin => "join"
case _ => nodeName.toLowerCase
}
@@ -391,9 +391,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
// The build side can't be compiled together
- case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) =>
+ case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) =>
b.copy(left = apply(left))
- case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
+ case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
b.copy(right = apply(right))
case p if !supportCodegen(p) =>
val input = apply(p) // collapse them recursively
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 985e74011d..a64da22580 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -24,8 +24,9 @@ import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -41,6 +42,7 @@ import org.apache.spark.util.collection.CompactBuffer
case class BroadcastHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
+ joinType: JoinType,
buildSide: BuildSide,
condition: Option[Expression],
left: SparkPlan,
@@ -105,75 +107,144 @@ case class BroadcastHashJoin(
val broadcastRelation = Await.result(broadcastFuture, timeout)
streamedPlan.execute().mapPartitions { streamedIter =>
- val hashedRelation = broadcastRelation.value
- TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
- hashJoin(streamedIter, hashedRelation, numOutputRows)
+ val joinedRow = new JoinedRow()
+ val hashTable = broadcastRelation.value
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
+ val keyGenerator = streamSideKeyGenerator
+ val resultProj = createResultProjection
+
+ joinType match {
+ case Inner =>
+ hashJoin(streamedIter, hashTable, numOutputRows)
+
+ case LeftOuter =>
+ streamedIter.flatMap { currentRow =>
+ val rowKey = keyGenerator(currentRow)
+ joinedRow.withLeft(currentRow)
+ leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
+ }
+
+ case RightOuter =>
+ streamedIter.flatMap { currentRow =>
+ val rowKey = keyGenerator(currentRow)
+ joinedRow.withRight(currentRow)
+ rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
+ }
+
+ case x =>
+ throw new IllegalArgumentException(
+ s"BroadcastHashJoin should not take $x as the JoinType")
+ }
}
}
- private var broadcastRelation: Broadcast[HashedRelation] = _
- // the term for hash relation
- private var relationTerm: String = _
-
override def upstream(): RDD[InternalRow] = {
streamedPlan.asInstanceOf[CodegenSupport].upstream()
}
override def doProduce(ctx: CodegenContext): String = {
+ streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ if (joinType == Inner) {
+ codegenInner(ctx, input)
+ } else {
+ // LeftOuter and RightOuter
+ codegenOuter(ctx, input)
+ }
+ }
+
+ /**
+ * Returns a tuple of Broadcast of HashedRelation and the variable name for it.
+ */
+ private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = {
// create a name for HashedRelation
- broadcastRelation = Await.result(broadcastFuture, timeout)
+ val broadcastRelation = Await.result(broadcastFuture, timeout)
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
- relationTerm = ctx.freshName("relation")
+ val relationTerm = ctx.freshName("relation")
val clsName = broadcastRelation.value.getClass.getName
ctx.addMutableState(clsName, relationTerm,
s"""
| $relationTerm = ($clsName) $broadcast.value();
| incPeakExecutionMemory($relationTerm.getMemorySize());
""".stripMargin)
-
- s"""
- | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
- """.stripMargin
+ (broadcastRelation, relationTerm)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
- // generate the key as UnsafeRow or Long
+ /**
+ * Returns the code for generating the join key for the stream side, and an expression
+ * indicating whether the key has any null in it.
+ */
+ private def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
ctx.currentVars = input
- val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) {
+ if (canJoinKeyFitWithinLong) {
+ // generate the join key as Long
val expr = rewriteKeyExpr(streamedKeys).head
val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
(ev, ev.isNull)
} else {
+ // generate the join key as UnsafeRow
val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
(ev, s"${ev.value}.anyNull()")
}
+ }
- // find the matches from HashedRelation
- val matched = ctx.freshName("matched")
-
- // create variables for output
+ /**
+ * Generates the code for the variables of the build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = {
ctx.currentVars = null
ctx.INPUT_ROW = matched
- val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
- BoundReference(i, a.dataType, a.nullable).gen(ctx)
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx)
+ if (joinType == Inner) {
+ ev
+ } else {
+ // the variables are needed even if there are no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val code = s"""
+ |boolean $isNull = true;
+ |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, isNull, value)
+ }
}
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
val resultVars = buildSide match {
- case BuildLeft => buildColumns ++ input
- case BuildRight => input ++ buildColumns
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
}
-
val numOutput = metricTerm(ctx, "numOutputRows")
+
val outputCode = if (condition.isDefined) {
// filter the output via condition
ctx.currentVars = resultVars
val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
s"""
- | ${ev.code}
- | if (!${ev.isNull} && ${ev.value}) {
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- | }
+ |${ev.code}
+ |if (!${ev.isNull} && ${ev.value}) {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ |}
""".stripMargin
} else {
s"""
@@ -184,36 +255,110 @@ case class BroadcastHashJoin(
if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
s"""
- | // generate join key
- | ${keyVal.code}
- | // find matches from HashedRelation
- | UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value});
- | if ($matched != null) {
- | ${buildColumns.map(_.code).mkString("\n")}
- | $outputCode
- | }
- """.stripMargin
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | ${buildVars.map(_.code).mkString("\n")}
+ | $outputCode
+ |}
+ """.stripMargin
+
+ } else {
+ val matches = ctx.freshName("matches")
+ val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+ val i = ctx.freshName("i")
+ val size = ctx.freshName("size")
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | int $size = $matches.size();
+ | for (int $i = 0; $i < $size; $i++) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+ | ${buildVars.map(_.code).mkString("\n")}
+ | $outputCode
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+
+ /**
+ * Generates the code for left or right outer join.
+ */
+ private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ // filter the output via condition
+ val conditionPassed = ctx.freshName("conditionPassed")
+ val checkCondition = if (condition.isDefined) {
+ ctx.currentVars = resultVars
+ val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+ s"""
+ |boolean $conditionPassed = true;
+ |if ($matched != null) {
+ | ${ev.code}
+ | $conditionPassed = !${ev.isNull} && ${ev.value};
+ |}
+ """.stripMargin
+ } else {
+ s"final boolean $conditionPassed = true;"
+ }
+
+ if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |${buildVars.map(_.code).mkString("\n")}
+ |${checkCondition.trim}
+ |if (!$conditionPassed) {
+ | // reset to null
+ | ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")}
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVars)}
+ """.stripMargin
} else {
val matches = ctx.freshName("matches")
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
val i = ctx.freshName("i")
val size = ctx.freshName("size")
+ val found = ctx.freshName("found")
s"""
- | // generate join key
- | ${keyVal.code}
- | // find matches from HashRelation
- | $bufferType $matches = ${anyNull} ? null :
- | ($bufferType) $relationTerm.get(${keyVal.value});
- | if ($matches != null) {
- | int $size = $matches.size();
- | for (int $i = 0; $i < $size; $i++) {
- | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
- | ${buildColumns.map(_.code).mkString("\n")}
- | $outputCode
- | }
- | }
- """.stripMargin
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+ |int $size = $matches != null ? $matches.size() : 0;
+ |boolean $found = false;
+ |// the last iteration of this loop emits an empty row if there are no matched rows.
+ |for (int $i = 0; $i <= $size; $i++) {
+ | UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
+ | ${buildVars.map(_.code).mkString("\n")}
+ | ${checkCondition.trim}
+ | if ($conditionPassed && ($i < $size || !$found)) {
+ | $found = true;
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
}
}
}
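
For context, here is a minimal usage sketch (not part of the patch) of the behavior the new outer-join paths above cover, assuming a SQLContext named sqlContext in scope and illustrative column names: a left outer join against a broadcast relation should now plan to this BroadcastHashJoin with LeftOuter/BuildRight and pad unmatched stream-side rows with nulls instead of dropping them.

    import org.apache.spark.sql.functions.broadcast

    // dim has keys 0..4, fact has ids 0..9, so ids 5..9 have no match
    val dim = sqlContext.range(5).selectExpr("id as k", "cast(id as string) as v")
    val fact = sqlContext.range(10)

    val joined = fact.join(broadcast(dim), fact("id") === dim("k"), "left_outer")
    joined.explain()  // physical plan should contain BroadcastHashJoin with LeftOuter, BuildRight
    joined.count()    // 10: every stream-side row is emitted exactly once;
                      // k and v come back null for ids 5..9
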
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
deleted file mode 100644
index 5e8c8ca043..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import scala.concurrent._
-import scala.concurrent.duration._
-
-import org.apache.spark.{InternalAccumulator, TaskContext}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-/**
- * Performs a outer hash join for two child relations. When the output RDD of this operator is
- * being constructed, a Spark job is asynchronously started to calculate the values for the
- * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed
- * relation is not shuffled.
- */
-case class BroadcastHashOuterJoin(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- joinType: JoinType,
- condition: Option[Expression],
- left: SparkPlan,
- right: SparkPlan) extends BinaryNode with HashOuterJoin {
-
- override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
- val timeout = {
- val timeoutValue = sqlContext.conf.broadcastTimeout
- if (timeoutValue < 0) {
- Duration.Inf
- } else {
- timeoutValue.seconds
- }
- }
-
- override def requiredChildDistribution: Seq[Distribution] =
- UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
-
- override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
-
- // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value
- // for the same query.
- @transient
- private lazy val broadcastFuture = {
- // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
- val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
- Future {
- // This will run in another thread. Set the execution id so that we can connect these jobs
- // with the correct execution.
- SQLExecution.withExecutionId(sparkContext, executionId) {
- // Note that we use .execute().collect() because we don't want to convert data to Scala
- // types
- val input: Array[InternalRow] = buildPlan.execute().map { row =>
- row.copy()
- }.collect()
- val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
- sparkContext.broadcast(hashed)
- }
- }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
- }
-
- protected override def doPrepare(): Unit = {
- broadcastFuture
- }
-
- override def doExecute(): RDD[InternalRow] = {
- val numOutputRows = longMetric("numOutputRows")
-
- val broadcastRelation = Await.result(broadcastFuture, timeout)
-
- streamedPlan.execute().mapPartitions { streamedIter =>
- val joinedRow = new JoinedRow()
- val hashTable = broadcastRelation.value
- val keyGenerator = streamedKeyGenerator
- TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
-
- val resultProj = resultProjection
- joinType match {
- case LeftOuter =>
- streamedIter.flatMap(currentRow => {
- val rowKey = keyGenerator(currentRow)
- joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
- })
-
- case RightOuter =>
- streamedIter.flatMap(currentRow => {
- val rowKey = keyGenerator(currentRow)
- joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
- })
-
- case x =>
- throw new IllegalArgumentException(
- s"BroadcastHashOuterJoin should not take $x as the JoinType")
- }
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 332a748d3b..2fe9c06cc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -21,20 +21,38 @@ import java.util.NoSuchElementException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.sql.types.{IntegralType, LongType}
+import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType}
+import org.apache.spark.util.collection.CompactBuffer
trait HashJoin {
self: SparkPlan =>
val leftKeys: Seq[Expression]
val rightKeys: Seq[Expression]
+ val joinType: JoinType
val buildSide: BuildSide
val condition: Option[Expression]
val left: SparkPlan
val right: SparkPlan
+ override def output: Seq[Attribute] = {
+ joinType match {
+ case Inner =>
+ left.output ++ right.output
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case FullOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+ case x =>
+ throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType")
+ }
+ }
+
protected lazy val (buildPlan, streamedPlan) = buildSide match {
case BuildLeft => (left, right)
case BuildRight => (right, left)
@@ -45,8 +63,6 @@ trait HashJoin {
case BuildRight => (rightKeys, leftKeys)
}
- override def output: Seq[Attribute] = left.output ++ right.output
-
/**
* Try to rewrite the key as LongType so we can use getLong(), if the key can fit within a long.
*
@@ -67,8 +83,17 @@ trait HashJoin {
width = dt.defaultSize
} else {
val bits = dt.defaultSize * 8
+ // The hashCode of a Long is (l >> 32) ^ l.toInt, so a long whose high 32 bits equal its
+ // low 32 bits hashes to 0. To avoid the worst case, where keys built from two identical
+ // ints all have hash code 0, we rotate the bits of the second one.
+ val rotated = if (e.dataType == IntegerType) {
+ // (e >>> 15) | (e << 17)
+ BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17)))
+ } else {
+ e
+ }
keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
- BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+ BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1)))
width -= bits
}
// TODO: support BooleanType, DateType and TimestampType
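
The rotation added above can be checked with a small standalone sketch (plain Scala, illustrative names, not the generated code): java.lang.Long.hashCode(l) is (l >>> 32) ^ l, so packing a pair of identical ints into one long without rotation makes every such key hash to 0, while rotating the low word by 15 bits (the same (e >>> 15) | (e << 17), i.e. Integer.rotateRight(e, 15)) spreads the hash codes out again.

    object RotatedKeyDemo {
      // pack two ints into one long key, mirroring what rewriteKeyExpr does for two IntegerType keys
      def packNaive(a: Int, b: Int): Long =
        (a.toLong << 32) | (b & 0xFFFFFFFFL)
      def packRotated(a: Int, b: Int): Long =
        (a.toLong << 32) | (Integer.rotateRight(b, 15) & 0xFFFFFFFFL)

      def main(args: Array[String]): Unit = {
        val keys = 0 until 100000
        val naive = keys.map(i => java.lang.Long.hashCode(packNaive(i, i))).distinct.size
        val rotated = keys.map(i => java.lang.Long.hashCode(packRotated(i, i))).distinct.size
        println(s"distinct hash codes without rotation: $naive")   // 1: every (i, i) key collides
        println(s"distinct hash codes with rotation:    $rotated") // 100000: no collisions
      }
    }
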
@@ -97,11 +122,13 @@ trait HashJoin {
(r: InternalRow) => true
}
+ protected def createResultProjection: (InternalRow) => InternalRow =
+ UnsafeProjection.create(self.schema)
+
protected def hashJoin(
streamIter: Iterator[InternalRow],
hashedRelation: HashedRelation,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] =
- {
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
new Iterator[InternalRow] {
private[this] var currentStreamedRow: InternalRow = _
private[this] var currentHashMatches: Seq[InternalRow] = _
@@ -109,8 +136,7 @@ trait HashJoin {
// Mutable per row objects.
private[this] val joinRow = new JoinedRow
- private[this] val resultProjection: (InternalRow) => InternalRow =
- UnsafeProjection.create(self.schema)
+ private[this] val resultProjection = createResultProjection
private[this] val joinKeys = streamSideKeyGenerator
@@ -163,4 +189,73 @@ trait HashJoin {
}
}
}
+
+ @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
+
+ @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
+ @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
+
+ protected[this] def leftOuterIterator(
+ key: InternalRow,
+ joinedRow: JoinedRow,
+ rightIter: Iterable[InternalRow],
+ resultProjection: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+ val ret: Iterable[InternalRow] = {
+ if (!key.anyNull) {
+ val temp = if (rightIter != null) {
+ rightIter.collect {
+ case r if boundCondition(joinedRow.withRight(r)) => {
+ numOutputRows += 1
+ resultProjection(joinedRow).copy()
+ }
+ }
+ } else {
+ List.empty
+ }
+ if (temp.isEmpty) {
+ numOutputRows += 1
+ resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+ } else {
+ temp
+ }
+ } else {
+ numOutputRows += 1
+ resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+ }
+ }
+ ret.iterator
+ }
+
+ protected[this] def rightOuterIterator(
+ key: InternalRow,
+ leftIter: Iterable[InternalRow],
+ joinedRow: JoinedRow,
+ resultProjection: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+ val ret: Iterable[InternalRow] = {
+ if (!key.anyNull) {
+ val temp = if (leftIter != null) {
+ leftIter.collect {
+ case l if boundCondition(joinedRow.withLeft(l)) => {
+ numOutputRows += 1
+ resultProjection(joinedRow).copy()
+ }
+ }
+ } else {
+ List.empty
+ }
+ if (temp.isEmpty) {
+ numOutputRows += 1
+ resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
+ } else {
+ temp
+ }
+ } else {
+ numOutputRows += 1
+ resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
+ }
+ }
+ ret.iterator
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
deleted file mode 100644
index 9e614309de..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ /dev/null
@@ -1,153 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.util.collection.CompactBuffer
-
-
-trait HashOuterJoin {
- self: SparkPlan =>
-
- val leftKeys: Seq[Expression]
- val rightKeys: Seq[Expression]
- val joinType: JoinType
- val condition: Option[Expression]
- val left: SparkPlan
- val right: SparkPlan
-
- override def output: Seq[Attribute] = {
- joinType match {
- case LeftOuter =>
- left.output ++ right.output.map(_.withNullability(true))
- case RightOuter =>
- left.output.map(_.withNullability(true)) ++ right.output
- case FullOuter =>
- left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case x =>
- throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
- }
- }
-
- protected[this] lazy val (buildPlan, streamedPlan) = joinType match {
- case RightOuter => (left, right)
- case LeftOuter => (right, left)
- case x =>
- throw new IllegalArgumentException(
- s"HashOuterJoin should not take $x as the JoinType")
- }
-
- protected[this] lazy val (buildKeys, streamedKeys) = joinType match {
- case RightOuter => (leftKeys, rightKeys)
- case LeftOuter => (rightKeys, leftKeys)
- case x =>
- throw new IllegalArgumentException(
- s"HashOuterJoin should not take $x as the JoinType")
- }
-
- protected def buildKeyGenerator: Projection =
- UnsafeProjection.create(buildKeys, buildPlan.output)
-
- protected[this] def streamedKeyGenerator: Projection =
- UnsafeProjection.create(streamedKeys, streamedPlan.output)
-
- protected[this] def resultProjection: InternalRow => InternalRow =
- UnsafeProjection.create(output, output)
-
- @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
- @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
-
- @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
- @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
- @transient private[this] lazy val boundCondition = if (condition.isDefined) {
- newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
- } else {
- (row: InternalRow) => true
- }
-
- // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
- // iterator for performance purpose.
-
- protected[this] def leftOuterIterator(
- key: InternalRow,
- joinedRow: JoinedRow,
- rightIter: Iterable[InternalRow],
- resultProjection: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val ret: Iterable[InternalRow] = {
- if (!key.anyNull) {
- val temp = if (rightIter != null) {
- rightIter.collect {
- case r if boundCondition(joinedRow.withRight(r)) => {
- numOutputRows += 1
- resultProjection(joinedRow).copy()
- }
- }
- } else {
- List.empty
- }
- if (temp.isEmpty) {
- numOutputRows += 1
- resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
- } else {
- temp
- }
- } else {
- numOutputRows += 1
- resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
- }
- }
- ret.iterator
- }
-
- protected[this] def rightOuterIterator(
- key: InternalRow,
- leftIter: Iterable[InternalRow],
- joinedRow: JoinedRow,
- resultProjection: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val ret: Iterable[InternalRow] = {
- if (!key.anyNull) {
- val temp = if (leftIter != null) {
- leftIter.collect {
- case l if boundCondition(joinedRow.withLeft(l)) => {
- numOutputRows += 1
- resultProjection(joinedRow).copy()
- }
- }
- } else {
- List.empty
- }
- if (temp.isEmpty) {
- numOutputRows += 1
- resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
- } else {
- temp
- }
- } else {
- numOutputRows += 1
- resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
- }
- }
- ret.iterator
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 9a3c262e94..92ff7e73fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -46,7 +46,6 @@ class JoinSuite extends QueryTest with SharedSQLContext {
val operators = physical.collect {
case j: LeftSemiJoinHash => j
case j: BroadcastHashJoin => j
- case j: BroadcastHashOuterJoin => j
case j: LeftSemiJoinBNL => j
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
@@ -123,9 +122,9 @@ class JoinSuite extends QueryTest with SharedSQLContext {
("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
classOf[SortMergeOuterJoin]),
("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
- classOf[BroadcastHashOuterJoin]),
+ classOf[BroadcastHashJoin]),
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
- classOf[BroadcastHashOuterJoin])
+ classOf[BroadcastHashJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
sql("UNCACHE TABLE testData")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 4a151179bf..bcac660a35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.util.HashMap
+
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.SQLContext
@@ -124,37 +126,65 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
ignore("broadcast hash join") {
val N = 100 << 20
- val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))
+ val M = 1 << 16
+ val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v"))
runBenchmark("Join w long", N) {
- sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count()
+ sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count()
}
/*
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
- BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w long codegen=false 10174 / 10317 10.0 100.0 1.0X
- Join w long codegen=true 1069 / 1107 98.0 10.2 9.5X
+ Join w long codegen=false 5744 / 5814 18.3 54.8 1.0X
+ Join w long codegen=true 735 / 853 142.7 7.0 7.8X
*/
- val dim2 = broadcast(sqlContext.range(1 << 16)
+ val dim2 = broadcast(sqlContext.range(M)
.selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v"))
runBenchmark("Join w 2 ints", N) {
sqlContext.range(N).join(dim2,
- (col("id") bitwiseAND 60000).cast(IntegerType) === col("k1")
- && (col("id") bitwiseAND 50000).cast(IntegerType) === col("k2")).count()
+ (col("id") bitwiseAND M).cast(IntegerType) === col("k1")
+ && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count()
}
/**
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
- BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w 2 ints codegen=false 11435 / 11530 9.0 111.1 1.0X
- Join w 2 ints codegen=true 1265 / 1424 82.0 12.2 9.0X
+ Join w 2 ints codegen=false 7159 / 7224 14.6 68.3 1.0X
+ Join w 2 ints codegen=true 1135 / 1197 92.4 10.8 6.3X
*/
+ val dim3 = broadcast(sqlContext.range(M)
+ .selectExpr("id as k1", "id as k2", "cast(id as string) as v"))
+
+ runBenchmark("Join w 2 longs", N) {
+ sqlContext.range(N).join(dim3,
+ (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
+ .count()
+ }
+
+ /**
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Join w 2 longs codegen=false 7877 / 8358 13.3 75.1 1.0X
+ Join w 2 longs codegen=true 3877 / 3937 27.0 37.0 2.0X
+ */
+ runBenchmark("outer join w long", N) {
+ sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count()
+ }
+
+ /**
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ outer join w long codegen=false 15280 / 16497 6.9 145.7 1.0X
+ outer join w long codegen=true 769 / 796 136.3 7.3 19.9X
+ */
}
ignore("rube") {
@@ -175,7 +205,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
ignore("hash and BytesToBytesMap") {
- val N = 50 << 20
+ val N = 10 << 20
val benchmark = new Benchmark("BytesToBytesMap", N)
@@ -227,6 +257,80 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
}
+ benchmark.addCase("Java HashMap (Long)") { iter =>
+ var i = 0
+ val keyBytes = new Array[Byte](16)
+ val valueBytes = new Array[Byte](16)
+ val value = new UnsafeRow(1)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ value.setInt(0, 555)
+ val map = new HashMap[Long, UnsafeRow]()
+ while (i < 65536) {
+ value.setInt(0, i)
+ map.put(i.toLong, value)
+ i += 1
+ }
+ var s = 0
+ i = 0
+ while (i < N) {
+ if (map.get(i % 100000) != null) {
+ s += 1
+ }
+ i += 1
+ }
+ }
+
+ benchmark.addCase("Java HashMap (two ints) ") { iter =>
+ var i = 0
+ val valueBytes = new Array[Byte](16)
+ val value = new UnsafeRow(1)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ value.setInt(0, 555)
+ val map = new HashMap[Long, UnsafeRow]()
+ while (i < 65536) {
+ value.setInt(0, i)
+ val key = (i.toLong << 32) + Integer.rotateRight(i, 15)
+ map.put(key, value)
+ i += 1
+ }
+ var s = 0
+ i = 0
+ while (i < N) {
+ val key = ((i & 100000).toLong << 32) + Integer.rotateRight(i & 100000, 15)
+ if (map.get(key) != null) {
+ s += 1
+ }
+ i += 1
+ }
+ }
+
+ benchmark.addCase("Java HashMap (UnsafeRow)") { iter =>
+ var i = 0
+ val keyBytes = new Array[Byte](16)
+ val valueBytes = new Array[Byte](16)
+ val key = new UnsafeRow(1)
+ key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ val value = new UnsafeRow(1)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ value.setInt(0, 555)
+ val map = new HashMap[UnsafeRow, UnsafeRow]()
+ while (i < 65536) {
+ key.setInt(0, i)
+ value.setInt(0, i)
+ map.put(key, value.copy())
+ i += 1
+ }
+ var s = 0
+ i = 0
+ while (i < N) {
+ key.setInt(0, i % 100000)
+ if (map.get(key) != null) {
+ s += 1
+ }
+ i += 1
+ }
+ }
+
Seq("off", "on").foreach { heap =>
benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
val taskMemoryManager = new TaskMemoryManager(
@@ -268,6 +372,9 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
hash 651 / 678 80.0 12.5 1.0X
fast hash 336 / 343 155.9 6.4 1.9X
arrayEqual 417 / 428 125.0 8.0 1.6X
+ Java HashMap (Long) 145 / 168 72.2 13.8 0.8X
+ Java HashMap (two ints) 157 / 164 66.8 15.0 0.8X
+ Java HashMap (UnsafeRow) 538 / 573 19.5 51.3 0.2X
BytesToBytesMap (off Heap) 2594 / 2664 20.2 49.5 0.2X
BytesToBytesMap (on Heap) 2693 / 2989 19.5 51.4 0.2X
*/
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index aee8e84db5..e25b5e0610 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -73,7 +73,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
}
test("unsafe broadcast hash outer join updates peak execution memory") {
- testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer")
+ testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash outer join", "left_outer")
}
test("unsafe broadcast left semi join updates peak execution memory") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 149f34dbd7..e22a810a6b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -88,7 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
- joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
+ joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, boundCondition, leftPlan, rightPlan)
}
def makeSortMergeJoin(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 3d3e9a7b90..f4b01fbad0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -75,11 +75,16 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
}
if (joinType != FullOuter) {
- test(s"$testName using BroadcastHashOuterJoin") {
+ test(s"$testName using BroadcastHashJoin") {
+ val buildSide = joinType match {
+ case LeftOuter => BuildRight
+ case RightOuter => BuildLeft
+ }
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
+ BroadcastHashJoin(
+ leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index f4bc9e501c..46bb699b78 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -209,20 +209,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
)
}
- test("BroadcastHashOuterJoin metrics") {
+ test("BroadcastHashJoin(outer) metrics") {
val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value")
val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value")
// Assume the execution plan is
- // ... -> BroadcastHashOuterJoin(nodeId = 0)
+ // ... -> BroadcastHashJoin(nodeId = 0)
val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer")
testSparkPlanMetrics(df, 2, Map(
- 0L -> ("BroadcastHashOuterJoin", Map(
+ 0L -> ("BroadcastHashJoin", Map(
"number of output rows" -> 5L)))
)
val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer")
testSparkPlanMetrics(df3, 2, Map(
- 0L -> ("BroadcastHashOuterJoin", Map(
+ 0L -> ("BroadcastHashJoin", Map(
"number of output rows" -> 6L)))
)
}