aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala 138
1 files changed, 77 insertions, 61 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
index ab20ee573a..c117dff9c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
@@ -199,99 +199,115 @@ case class SortMergeOuterJoin(
}
}
-
+/**
+ * An iterator for outputting rows in left outer join.
+ */
private class LeftOuterIterator(
smjScanner: SortMergeJoinScanner,
rightNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
- numRows: LongSQLMetric
- ) extends RowIterator {
- private[this] val joinedRow: JoinedRow = new JoinedRow()
- private[this] var rightIdx: Int = 0
- assert(smjScanner.getBufferedMatches.length == 0)
-
- private def advanceLeft(): Boolean = {
- rightIdx = 0
- if (smjScanner.findNextOuterJoinRows()) {
- joinedRow.withLeft(smjScanner.getStreamedRow)
- if (smjScanner.getBufferedMatches.isEmpty) {
- // There are no matching right rows, so return nulls for the right row
- joinedRow.withRight(rightNullRow)
- } else {
- // Find the next row from the right input that satisfied the bound condition
- if (!advanceRightUntilBoundConditionSatisfied()) {
- joinedRow.withRight(rightNullRow)
- }
- }
- true
- } else {
- // Left input has been exhausted
- false
- }
- }
-
- private def advanceRightUntilBoundConditionSatisfied(): Boolean = {
- var foundMatch: Boolean = false
- while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) {
- foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx)))
- rightIdx += 1
- }
- foundMatch
- }
-
- override def advanceNext(): Boolean = {
- val r = advanceRightUntilBoundConditionSatisfied() || advanceLeft()
- if (r) numRows += 1
- r
- }
+ numOutputRows: LongSQLMetric)
+ extends OneSideOuterIterator(
+ smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
- override def getRow: InternalRow = resultProj(joinedRow)
+ protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
+ protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
}
+/**
+ * An iterator for outputting rows in right outer join.
+ */
private class RightOuterIterator(
smjScanner: SortMergeJoinScanner,
leftNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
- numRows: LongSQLMetric
- ) extends RowIterator {
- private[this] val joinedRow: JoinedRow = new JoinedRow()
- private[this] var leftIdx: Int = 0
+ numOutputRows: LongSQLMetric)
+ extends OneSideOuterIterator(
+ smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
+
+ protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
+ protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
+}
+
+/**
+ * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]].
+ *
+ * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the
+ * streamed side will output one or more rows: one for each matching row on the buffered side,
+ * or, if there are no matches, a single row whose buffered side is the null row.
+ *
+ * In left outer join, the left is the streamed side and the right is the buffered side.
+ * In right outer join, the right is the streamed side and the left is the buffered side.
+ *
+ * @param smjScanner a scanner that streams rows and buffers any matching rows
+ * @param bufferedSideNullRow the default row to return when a streamed row has no matches
+ * @param boundCondition an additional filter condition for buffered rows
+ * @param resultProj how the output should be projected
+ * @param numOutputRows an accumulator metric for the number of rows output
+ */
+private abstract class OneSideOuterIterator(
+ smjScanner: SortMergeJoinScanner,
+ bufferedSideNullRow: InternalRow,
+ boundCondition: InternalRow => Boolean,
+ resultProj: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric) extends RowIterator {
+
+ // A row to store the joined result, reused many times
+ protected[this] val joinedRow: JoinedRow = new JoinedRow()
+
+ // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row
+ private[this] var bufferIndex: Int = 0
+
+ // This iterator is initialized lazily so there should be no matches initially
assert(smjScanner.getBufferedMatches.length == 0)
- private def advanceRight(): Boolean = {
- leftIdx = 0
+ // Set output methods to be overridden by subclasses
+ protected def setStreamSideOutput(row: InternalRow): Unit
+ protected def setBufferedSideOutput(row: InternalRow): Unit
+
+ /**
+ * Advance to the next row on the stream side and populate the buffer with matches.
+ * @return whether there are more rows in the stream to consume.
+ */
+ private def advanceStream(): Boolean = {
+ bufferIndex = 0
if (smjScanner.findNextOuterJoinRows()) {
- joinedRow.withRight(smjScanner.getStreamedRow)
+ setStreamSideOutput(smjScanner.getStreamedRow)
if (smjScanner.getBufferedMatches.isEmpty) {
- // There are no matching left rows, so return nulls for the left row
- joinedRow.withLeft(leftNullRow)
+ // There are no matching rows in the buffer, so return the null row
+ setBufferedSideOutput(bufferedSideNullRow)
} else {
- // Find the next row from the left input that satisfied the bound condition
- if (!advanceLeftUntilBoundConditionSatisfied()) {
- joinedRow.withLeft(leftNullRow)
+ // Find the next row in the buffer that satisfies the bound condition
+ if (!advanceBufferUntilBoundConditionSatisfied()) {
+ setBufferedSideOutput(bufferedSideNullRow)
}
}
true
} else {
- // Right input has been exhausted
+ // Stream has been exhausted
false
}
}
- private def advanceLeftUntilBoundConditionSatisfied(): Boolean = {
+ /**
+ * Advance to the next row in the buffer that satisfies the bound condition.
+ * @return whether there is such a row in the current buffer.
+ */
+ private def advanceBufferUntilBoundConditionSatisfied(): Boolean = {
var foundMatch: Boolean = false
- while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) {
- foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx)))
- leftIdx += 1
+ while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) {
+ setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex))
+ foundMatch = boundCondition(joinedRow)
+ bufferIndex += 1
}
foundMatch
}
override def advanceNext(): Boolean = {
- val r = advanceLeftUntilBoundConditionSatisfied() || advanceRight()
- if (r) numRows += 1
+ val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream()
+ if (r) numOutputRows += 1
r
}