diff options
Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala | 57 |
1 file changed, 35 insertions(+), 22 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 4143e944e5..4ba710c10a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -73,7 +73,7 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case LeftSemi => + case LeftExistence(_) => left.output case x => throw new IllegalArgumentException( @@ -175,8 +175,11 @@ case class BroadcastNestedLoopJoin( * The implementation for these joins: * * LeftSemi with BuildRight + * Anti with BuildRight */ - private def leftSemiJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + private def leftExistenceJoin( + relation: Broadcast[Array[InternalRow]], + exists: Boolean): RDD[InternalRow] = { assert(buildSide == BuildRight) streamed.execute().mapPartitionsInternal { streamedIter => val buildRows = relation.value @@ -184,10 +187,12 @@ case class BroadcastNestedLoopJoin( if (condition.isDefined) { streamedIter.filter(l => - buildRows.exists(r => boundCondition(joinedRow(l, r))) + buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists ) + } else if (buildRows.nonEmpty == exists) { + streamedIter } else { - streamedIter.filter(r => !buildRows.isEmpty) + Iterator.empty } } } @@ -199,6 +204,7 @@ case class BroadcastNestedLoopJoin( * RightOuter with BuildRight * FullOuter * LeftSemi with BuildLeft + * Anti with BuildLeft */ private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { /** All rows that either match both-way, or rows from streamed joined with nulls. 
*/ @@ -236,7 +242,27 @@ case class BroadcastNestedLoopJoin( } i += 1 } - return sparkContext.makeRDD(buf.toSeq) + return sparkContext.makeRDD(buf) + } + + val notMatchedBroadcastRows: Seq[InternalRow] = { + val nulls = new GenericMutableRow(streamed.output.size) + val buf: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val buildRows = relation.value + val joinedRow = new JoinedRow + joinedRow.withLeft(nulls) + while (i < buildRows.length) { + if (!matchedBroadcastRows.get(i)) { + buf += joinedRow.withRight(buildRows(i)).copy() + } + i += 1 + } + buf + } + + if (joinType == LeftAnti) { + return sparkContext.makeRDD(notMatchedBroadcastRows) } val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter => @@ -264,22 +290,6 @@ case class BroadcastNestedLoopJoin( } } - val notMatchedBroadcastRows: Seq[InternalRow] = { - val nulls = new GenericMutableRow(streamed.output.size) - val buf: CompactBuffer[InternalRow] = new CompactBuffer() - var i = 0 - val buildRows = relation.value - val joinedRow = new JoinedRow - joinedRow.withLeft(nulls) - while (i < buildRows.length) { - if (!matchedBroadcastRows.get(i)) { - buf += joinedRow.withRight(buildRows(i)).copy() - } - i += 1 - } - buf.toSeq - } - sparkContext.union( matchedStreamRows, sparkContext.makeRDD(notMatchedBroadcastRows) @@ -295,13 +305,16 @@ case class BroadcastNestedLoopJoin( case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => outerJoin(broadcastedRelation) case (LeftSemi, BuildRight) => - leftSemiJoin(broadcastedRelation) + leftExistenceJoin(broadcastedRelation, exists = true) + case (LeftAnti, BuildRight) => + leftExistenceJoin(broadcastedRelation, exists = false) case _ => /** * LeftOuter with BuildLeft * RightOuter with BuildRight * FullOuter * LeftSemi with BuildLeft + * Anti with BuildLeft */ defaultJoin(broadcastedRelation) } |