aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala57
1 files changed, 35 insertions, 22 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 4143e944e5..4ba710c10a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -73,7 +73,7 @@ case class BroadcastNestedLoopJoin(
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case LeftSemi =>
+ case LeftExistence(_) =>
left.output
case x =>
throw new IllegalArgumentException(
@@ -175,8 +175,11 @@ case class BroadcastNestedLoopJoin(
* The implementation for these joins:
*
* LeftSemi with BuildRight
+ * Anti with BuildRight
*/
- private def leftSemiJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+ private def leftExistenceJoin(
+ relation: Broadcast[Array[InternalRow]],
+ exists: Boolean): RDD[InternalRow] = {
assert(buildSide == BuildRight)
streamed.execute().mapPartitionsInternal { streamedIter =>
val buildRows = relation.value
@@ -184,10 +187,12 @@ case class BroadcastNestedLoopJoin(
if (condition.isDefined) {
streamedIter.filter(l =>
- buildRows.exists(r => boundCondition(joinedRow(l, r)))
+ buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists
)
+ } else if (buildRows.nonEmpty == exists) {
+ streamedIter
} else {
- streamedIter.filter(r => !buildRows.isEmpty)
+ Iterator.empty
}
}
}
@@ -199,6 +204,7 @@ case class BroadcastNestedLoopJoin(
* RightOuter with BuildRight
* FullOuter
* LeftSemi with BuildLeft
+ * Anti with BuildLeft
*/
private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
/** All rows that either match both-way, or rows from streamed joined with nulls. */
@@ -236,7 +242,27 @@ case class BroadcastNestedLoopJoin(
}
i += 1
}
- return sparkContext.makeRDD(buf.toSeq)
+ return sparkContext.makeRDD(buf)
+ }
+
+ val notMatchedBroadcastRows: Seq[InternalRow] = {
+ val nulls = new GenericMutableRow(streamed.output.size)
+ val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+ var i = 0
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
+ joinedRow.withLeft(nulls)
+ while (i < buildRows.length) {
+ if (!matchedBroadcastRows.get(i)) {
+ buf += joinedRow.withRight(buildRows(i)).copy()
+ }
+ i += 1
+ }
+ buf
+ }
+
+ if (joinType == LeftAnti) {
+ return sparkContext.makeRDD(notMatchedBroadcastRows)
}
val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter =>
@@ -264,22 +290,6 @@ case class BroadcastNestedLoopJoin(
}
}
- val notMatchedBroadcastRows: Seq[InternalRow] = {
- val nulls = new GenericMutableRow(streamed.output.size)
- val buf: CompactBuffer[InternalRow] = new CompactBuffer()
- var i = 0
- val buildRows = relation.value
- val joinedRow = new JoinedRow
- joinedRow.withLeft(nulls)
- while (i < buildRows.length) {
- if (!matchedBroadcastRows.get(i)) {
- buf += joinedRow.withRight(buildRows(i)).copy()
- }
- i += 1
- }
- buf.toSeq
- }
-
sparkContext.union(
matchedStreamRows,
sparkContext.makeRDD(notMatchedBroadcastRows)
@@ -295,13 +305,16 @@ case class BroadcastNestedLoopJoin(
case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
outerJoin(broadcastedRelation)
case (LeftSemi, BuildRight) =>
- leftSemiJoin(broadcastedRelation)
+ leftExistenceJoin(broadcastedRelation, exists = true)
+ case (LeftAnti, BuildRight) =>
+ leftExistenceJoin(broadcastedRelation, exists = false)
case _ =>
/**
* LeftOuter with BuildLeft
* RightOuter with BuildRight
* FullOuter
* LeftSemi with BuildLeft
+ * Anti with BuildLeft
*/
defaultJoin(broadcastedRelation)
}