aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-02-26 09:58:05 -0800
committerDavies Liu <davies.liu@gmail.com>2016-02-26 09:58:05 -0800
commit6df1e55a6594ae4bc7882f44af8d230aad9489b4 (patch)
tree296c8e72e5b383f0647d9bd7b31853bdb5adc82d /sql
parent727e78014fd4957e477d62adc977fa4da3e3455d (diff)
downloadspark-6df1e55a6594ae4bc7882f44af8d230aad9489b4.tar.gz
spark-6df1e55a6594ae4bc7882f44af8d230aad9489b4.tar.bz2
spark-6df1e55a6594ae4bc7882f44af8d230aad9489b4.zip
[SPARK-12313] [SQL] improve performance of BroadcastNestedLoopJoin
## What changes were proposed in this pull request? Currently, BroadcastNestedLoopJoin is implemented for worst case, it's too slow, very easy to hang forever. This PR will create fast path for some joinType and buildSide, also improve the worst case (will use much less memory than before). Before this PR, one task requires O(N*K) + O(K) in worst cases, N is number of rows from one partition of streamed table, it could hang the job (because of GC). In order to workaround this for InnerJoin, we have to disable auto-broadcast, switch to CartesianProduct: This could be workaround for InnerJoin, see https://forums.databricks.com/questions/6747/how-do-i-get-a-cartesian-product-of-a-huge-dataset.html In this PR, we will have fast path for these joins : InnerJoin with BuildLeft or BuildRight LeftOuterJoin with BuildRight RightOuterJoin with BuildLeft LeftSemi with BuildRight These fast paths are all stream based (take one pass on streamed table), required O(1) memory. All other join types and build types will take two pass on streamed table, one pass to find the matched rows that includes streamed part, which require O(1) memory, another pass to find the rows from build table that does not have a matched row from streamed table, which required O(K) memory, K is the number rows from build side, one bit per row, should be much smaller than the memory for broadcast. The following join types work in this way: LeftOuterJoin with BuildLeft RightOuterJoin with BuildRight FullOuterJoin with BuildLeft or BuildRight LeftSemi with BuildLeft This PR also added tests for all the join types for BroadcastNestedLoopJoin. After this PR, for InnerJoin with one small table, BroadcastNestedLoopJoin should be faster than CartesianProduct, we don't need that workaround anymore. ## How was the this patch tested? Added unit tests. Author: Davies Liu <davies@databricks.com> Closes #11328 from davies/nested_loop.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala14
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala295
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala11
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala27
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala18
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala20
7 files changed, 295 insertions, 91 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
index c646dcfa11..e01f69f813 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
@@ -31,5 +31,6 @@ trait BroadcastMode {
* IdentityBroadcastMode requires that rows are broadcasted in their original form.
*/
case object IdentityBroadcastMode extends BroadcastMode {
+ // TODO: pack the UnsafeRows into single bytes array.
override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5fdf38c733..dd8c96d5fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -253,22 +253,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object BroadcastNestedLoop extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.Join(
- CanBroadcast(left), right, joinType, condition) if joinType != LeftSemi =>
+ case j @ logical.Join(CanBroadcast(left), right, Inner | RightOuter, condition) =>
execution.joins.BroadcastNestedLoopJoin(
- planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil
- case logical.Join(
- left, CanBroadcast(right), joinType, condition) if joinType != LeftSemi =>
+ planLater(left), planLater(right), joins.BuildLeft, j.joinType, condition) :: Nil
+ case j @ logical.Join(left, CanBroadcast(right), Inner | LeftOuter | LeftSemi, condition) =>
execution.joins.BroadcastNestedLoopJoin(
- planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil
+ planLater(left), planLater(right), joins.BuildRight, j.joinType, condition) :: Nil
case _ => Nil
}
}
object CartesianProduct extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- // TODO CartesianProduct doesn't support the Left Semi Join
- case logical.Join(left, right, joinType, None) if joinType != LeftSemi =>
+ case logical.Join(left, right, Inner, None) =>
execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil
case logical.Join(left, right, Inner, Some(condition)) =>
execution.Filter(condition,
@@ -286,6 +283,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
} else {
joins.BuildLeft
}
+ // This join could be very slow or even hang forever
joins.BroadcastNestedLoopJoin(
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
case _ => Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index e8bd7f69db..d83486df02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.joins
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -26,7 +27,6 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.collection.{BitSet, CompactBuffer}
-
case class BroadcastNestedLoopJoin(
left: SparkPlan,
right: SparkPlan,
@@ -51,125 +51,266 @@ case class BroadcastNestedLoopJoin(
}
private[this] def genResultProjection: InternalRow => InternalRow = {
- UnsafeProjection.create(schema)
+ if (joinType == LeftSemi) {
+ UnsafeProjection.create(output, output)
+ } else {
+ // Always put the stream side on left to simplify implementation
+ UnsafeProjection.create(output, streamed.output ++ broadcast.output)
+ }
}
override def outputPartitioning: Partitioning = streamed.outputPartitioning
override def output: Seq[Attribute] = {
joinType match {
+ case Inner =>
+ left.output ++ right.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case Inner =>
- // TODO we can avoid breaking the lineage, since we union an empty RDD for Inner Join case
- left.output ++ right.output
- case x => // TODO support the Left Semi Join
+ case LeftSemi =>
+ left.output
+ case x =>
throw new IllegalArgumentException(
s"BroadcastNestedLoopJoin should not take $x as the JoinType")
}
}
- @transient private lazy val boundCondition =
- newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+ @transient private lazy val boundCondition = {
+ if (condition.isDefined) {
+ newPredicate(condition.get, streamed.output ++ broadcast.output)
+ } else {
+ (r: InternalRow) => true
+ }
+ }
- protected override def doExecute(): RDD[InternalRow] = {
- val numOutputRows = longMetric("numOutputRows")
+ /**
+ * The implementation for InnerJoin.
+ */
+ private def innerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+ streamed.execute().mapPartitionsInternal { streamedIter =>
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
- val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
+ streamedIter.flatMap { streamedRow =>
+ val joinedRows = buildRows.iterator.map(r => joinedRow(streamedRow, r))
+ if (condition.isDefined) {
+ joinedRows.filter(boundCondition)
+ } else {
+ joinedRows
+ }
+ }
+ }
+ }
- /** All rows that either match both-way, or rows from streamed joined with nulls. */
- val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
- val relation = broadcastedRelation.value
+ /**
+ * The implementation for these joins:
+ *
+ * LeftOuter with BuildRight
+ * RightOuter with BuildLeft
+ */
+ private def outerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+ streamed.execute().mapPartitionsInternal { streamedIter =>
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
+ val nulls = new GenericMutableRow(broadcast.output.size)
+
+ // Returns an iterator to avoid copy the rows.
+ new Iterator[InternalRow] {
+ // current row from stream side
+ private var streamRow: InternalRow = null
+ // have found a match for current row or not
+ private var foundMatch: Boolean = false
+ // the matched result row
+ private var resultRow: InternalRow = null
+ // the next index of buildRows to try
+ private var nextIndex: Int = 0
- val matchedRows = new CompactBuffer[InternalRow]
- val includedBroadcastTuples = new BitSet(relation.length)
+ private def findNextMatch(): Boolean = {
+ if (streamRow == null) {
+ if (!streamedIter.hasNext) {
+ return false
+ }
+ streamRow = streamedIter.next()
+ nextIndex = 0
+ foundMatch = false
+ }
+ while (nextIndex < buildRows.length) {
+ resultRow = joinedRow(streamRow, buildRows(nextIndex))
+ nextIndex += 1
+ if (boundCondition(resultRow)) {
+ foundMatch = true
+ return true
+ }
+ }
+ if (!foundMatch) {
+ resultRow = joinedRow(streamRow, nulls)
+ streamRow = null
+ true
+ } else {
+ resultRow = null
+ streamRow = null
+ findNextMatch()
+ }
+ }
+
+ override def hasNext(): Boolean = {
+ resultRow != null || findNextMatch()
+ }
+ override def next(): InternalRow = {
+ val r = resultRow
+ resultRow = null
+ r
+ }
+ }
+ }
+ }
+
+ /**
+ * The implementation for these joins:
+ *
+ * LeftSemi with BuildRight
+ */
+ private def leftSemiJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+ assert(buildSide == BuildRight)
+ streamed.execute().mapPartitionsInternal { streamedIter =>
+ val buildRows = relation.value
val joinedRow = new JoinedRow
- val leftNulls = new GenericMutableRow(left.output.size)
- val rightNulls = new GenericMutableRow(right.output.size)
- val resultProj = genResultProjection
+ if (condition.isDefined) {
+ streamedIter.filter(l =>
+ buildRows.exists(r => boundCondition(joinedRow(l, r)))
+ )
+ } else {
+ streamedIter.filter(r => !buildRows.isEmpty)
+ }
+ }
+ }
+
+ /**
+ * The implementation for these joins:
+ *
+ * LeftOuter with BuildLeft
+ * RightOuter with BuildRight
+ * FullOuter
+ * LeftSemi with BuildLeft
+ */
+ private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+ /** All rows that either match both-way, or rows from streamed joined with nulls. */
+ val streamRdd = streamed.execute()
+
+ val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter =>
+ val buildRows = relation.value
+ val matched = new BitSet(buildRows.length)
+ val joinedRow = new JoinedRow
streamedIter.foreach { streamedRow =>
var i = 0
- var streamRowMatched = false
-
- while (i < relation.length) {
- val broadcastedRow = relation(i)
- buildSide match {
- case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
- matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy()
- streamRowMatched = true
- includedBroadcastTuples.set(i)
- case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
- matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy()
- streamRowMatched = true
- includedBroadcastTuples.set(i)
- case _ =>
+ while (i < buildRows.length) {
+ if (boundCondition(joinedRow(streamedRow, buildRows(i)))) {
+ matched.set(i)
}
i += 1
}
+ }
+ Seq(matched).toIterator
+ }
- (streamRowMatched, joinType, buildSide) match {
- case (false, LeftOuter | FullOuter, BuildRight) =>
- matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy()
- case (false, RightOuter | FullOuter, BuildLeft) =>
- matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy()
- case _ =>
+ val matchedBroadcastRows = matchedBuildRows.fold(
+ new BitSet(relation.value.length)
+ )(_ | _)
+
+ if (joinType == LeftSemi) {
+ assert(buildSide == BuildLeft)
+ val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+ var i = 0
+ val rel = relation.value
+ while (i < rel.length) {
+ if (matchedBroadcastRows.get(i)) {
+ buf += rel(i).copy()
}
+ i += 1
}
- Iterator((matchedRows, includedBroadcastTuples))
+ return sparkContext.makeRDD(buf.toSeq)
}
- val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
- val allIncludedBroadcastTuples = includedBroadcastTuples.fold(
- new BitSet(broadcastedRelation.value.size)
- )(_ | _)
+ val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter =>
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
+ val nulls = new GenericMutableRow(broadcast.output.size)
- val leftNulls = new GenericMutableRow(left.output.size)
- val rightNulls = new GenericMutableRow(right.output.size)
- val resultProj = genResultProjection
+ streamedIter.flatMap { streamedRow =>
+ var i = 0
+ var foundMatch = false
+ val matchedRows = new CompactBuffer[InternalRow]
+
+ while (i < buildRows.length) {
+ if (boundCondition(joinedRow(streamedRow, buildRows(i)))) {
+ matchedRows += joinedRow.copy()
+ foundMatch = true
+ }
+ i += 1
+ }
+
+ if (!foundMatch && joinType == FullOuter) {
+ matchedRows += joinedRow(streamedRow, nulls).copy()
+ }
+ matchedRows.iterator
+ }
+ }
- /** Rows from broadcasted joined with nulls. */
- val broadcastRowsWithNulls: Seq[InternalRow] = {
+ val notMatchedBroadcastRows: Seq[InternalRow] = {
+ val nulls = new GenericMutableRow(streamed.output.size)
val buf: CompactBuffer[InternalRow] = new CompactBuffer()
var i = 0
- val rel = broadcastedRelation.value
- (joinType, buildSide) match {
- case (RightOuter | FullOuter, BuildRight) =>
- val joinedRow = new JoinedRow
- joinedRow.withLeft(leftNulls)
- while (i < rel.length) {
- if (!allIncludedBroadcastTuples.get(i)) {
- buf += resultProj(joinedRow.withRight(rel(i))).copy()
- }
- i += 1
- }
- case (LeftOuter | FullOuter, BuildLeft) =>
- val joinedRow = new JoinedRow
- joinedRow.withRight(rightNulls)
- while (i < rel.length) {
- if (!allIncludedBroadcastTuples.get(i)) {
- buf += resultProj(joinedRow.withLeft(rel(i))).copy()
- }
- i += 1
- }
- case _ =>
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
+ joinedRow.withLeft(nulls)
+ while (i < buildRows.length) {
+ if (!matchedBroadcastRows.get(i)) {
+ buf += joinedRow.withRight(buildRows(i)).copy()
+ }
+ i += 1
}
buf.toSeq
}
- // TODO: Breaks lineage.
sparkContext.union(
- matchesOrStreamedRowsWithNulls.flatMap(_._1),
- sparkContext.makeRDD(broadcastRowsWithNulls)
- ).map { row =>
- // `broadcastRowsWithNulls` doesn't run in a job so that we have to track numOutputRows here.
- numOutputRows += 1
- row
+ matchedStreamRows,
+ sparkContext.makeRDD(notMatchedBroadcastRows)
+ )
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
+
+ val resultRdd = (joinType, buildSide) match {
+ case (Inner, _) =>
+ innerJoin(broadcastedRelation)
+ case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
+ outerJoin(broadcastedRelation)
+ case (LeftSemi, BuildRight) =>
+ leftSemiJoin(broadcastedRelation)
+ case _ =>
+ /**
+ * LeftOuter with BuildLeft
+ * RightOuter with BuildRight
+ * FullOuter
+ * LeftSemi with BuildLeft
+ */
+ defaultJoin(broadcastedRelation)
+ }
+
+ val numOutputRows = longMetric("numOutputRows")
+ resultRdd.mapPartitionsInternal { iter =>
+ val resultProj = genResultProjection
+ iter.map { r =>
+ numOutputRows += 1
+ resultProj(r)
+ }
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 41e27ec466..3dab848e7b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -70,13 +70,14 @@ class JoinSuite extends QueryTest with SharedSQLContext {
("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
- ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
- ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
- ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
- ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData LEFT JOIN testData2", classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData RIGHT JOIN testData2", classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2",
+ classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2",
- classOf[CartesianProduct]),
+ classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a",
classOf[CartesianProduct]),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index b748229e40..7eb15249eb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -146,6 +146,33 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
}
}
}
+
+ test(s"$testName using CartesianProduct") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ Filter(condition(), CartesianProduct(left, right)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+
+ test(s"$testName using BroadcastNestedLoopJoin build left") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ BroadcastNestedLoopJoin(left, right, BuildLeft, Inner, Some(condition())),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+
+ test(s"$testName using BroadcastNestedLoopJoin build right") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ BroadcastNestedLoopJoin(left, right, BuildRight, Inner, Some(condition())),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
}
testInnerJoin(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 22fe8caff2..0d1c29fe57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -105,6 +105,24 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
}
}
}
+
+ test(s"$testName using BroadcastNestedLoopJoin build left") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ BroadcastNestedLoopJoin(left, right, BuildLeft, joinType, Some(condition)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+
+ test(s"$testName using BroadcastNestedLoopJoin build right") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ BroadcastNestedLoopJoin(left, right, BuildRight, joinType, Some(condition)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
}
// --- Basic outer joins ------------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 5c982885d6..355f916a97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
-import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
import org.apache.spark.sql.execution.exchange.EnsureRequirements
@@ -103,6 +103,24 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
sortAnswers = true)
}
}
+
+ test(s"$testName using BroadcastNestedLoopJoin build left") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ BroadcastNestedLoopJoin(left, right, BuildLeft, LeftSemi, Some(condition)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+
+ test(s"$testName using BroadcastNestedLoopJoin build right") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ BroadcastNestedLoopJoin(left, right, BuildRight, LeftSemi, Some(condition)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
}
testLeftSemiJoin(