aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
author Andrew Or <andrew@databricks.com> 2015-09-10 13:22:35 -0700
committer Davies Liu <davies.liu@gmail.com> 2015-09-10 13:22:35 -0700
commit 3db72554be3f13478ccd5915e746491982163298 (patch)
tree ae2a84cfa1d0b4e38478cdf2c95307ad2193d948 /sql
parent 45e3be5c138d983f40f619735d60bf7eb78c9bf0 (diff)
download spark-3db72554be3f13478ccd5915e746491982163298.tar.gz
spark-3db72554be3f13478ccd5915e746491982163298.tar.bz2
spark-3db72554be3f13478ccd5915e746491982163298.zip
[SPARK-10443] [SQL] Refactor SortMergeOuterJoin to reduce duplication
`LeftOuterIterator` and `RightOuterIterator` are symmetric, nearly identical, and can share most of their code. If someone makes a change in one but forgets to make the same change in the other, we'll end up with inconsistent behavior. This patch also adds inline comments to clarify the intention of the code. Author: Andrew Or <andrew@databricks.com> Closes #8596 from andrewor14/smoj-cleanup.
Diffstat (limited to 'sql')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala 138
1 files changed, 77 insertions, 61 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
index ab20ee573a..c117dff9c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
@@ -199,99 +199,115 @@ case class SortMergeOuterJoin(
}
}
-
+/**
+ * An iterator for outputting rows in left outer join.
+ */
private class LeftOuterIterator(
smjScanner: SortMergeJoinScanner,
rightNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
- numRows: LongSQLMetric
- ) extends RowIterator {
- private[this] val joinedRow: JoinedRow = new JoinedRow()
- private[this] var rightIdx: Int = 0
- assert(smjScanner.getBufferedMatches.length == 0)
-
- private def advanceLeft(): Boolean = {
- rightIdx = 0
- if (smjScanner.findNextOuterJoinRows()) {
- joinedRow.withLeft(smjScanner.getStreamedRow)
- if (smjScanner.getBufferedMatches.isEmpty) {
- // There are no matching right rows, so return nulls for the right row
- joinedRow.withRight(rightNullRow)
- } else {
- // Find the next row from the right input that satisfied the bound condition
- if (!advanceRightUntilBoundConditionSatisfied()) {
- joinedRow.withRight(rightNullRow)
- }
- }
- true
- } else {
- // Left input has been exhausted
- false
- }
- }
-
- private def advanceRightUntilBoundConditionSatisfied(): Boolean = {
- var foundMatch: Boolean = false
- while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) {
- foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx)))
- rightIdx += 1
- }
- foundMatch
- }
-
- override def advanceNext(): Boolean = {
- val r = advanceRightUntilBoundConditionSatisfied() || advanceLeft()
- if (r) numRows += 1
- r
- }
+ numOutputRows: LongSQLMetric)
+ extends OneSideOuterIterator(
+ smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
- override def getRow: InternalRow = resultProj(joinedRow)
+ protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
+ protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
}
+/**
+ * An iterator for outputting rows in right outer join.
+ */
private class RightOuterIterator(
smjScanner: SortMergeJoinScanner,
leftNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
- numRows: LongSQLMetric
- ) extends RowIterator {
- private[this] val joinedRow: JoinedRow = new JoinedRow()
- private[this] var leftIdx: Int = 0
+ numOutputRows: LongSQLMetric)
+ extends OneSideOuterIterator(
+ smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
+
+ protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
+ protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
+}
+
+/**
+ * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]].
+ *
+ * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the
+ * streamed side will output 1 or many rows, one for each matching row on the buffered side.
+ * If there are no matches, then the buffered side of the joined output will be a null row.
+ *
+ * In left outer join, the left is the streamed side and the right is the buffered side.
+ * In right outer join, the right is the streamed side and the left is the buffered side.
+ *
+ * @param smjScanner a scanner that streams rows and buffers any matching rows
+ * @param bufferedSideNullRow the default row to return when a streamed row has no matches
+ * @param boundCondition an additional filter condition for buffered rows
+ * @param resultProj how the output should be projected
+ * @param numOutputRows an accumulator metric for the number of rows output
+ */
+private abstract class OneSideOuterIterator(
+ smjScanner: SortMergeJoinScanner,
+ bufferedSideNullRow: InternalRow,
+ boundCondition: InternalRow => Boolean,
+ resultProj: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric) extends RowIterator {
+
+ // A row to store the joined result, reused many times
+ protected[this] val joinedRow: JoinedRow = new JoinedRow()
+
+ // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row
+ private[this] var bufferIndex: Int = 0
+
+ // This iterator is initialized lazily so there should be no matches initially
assert(smjScanner.getBufferedMatches.length == 0)
- private def advanceRight(): Boolean = {
- leftIdx = 0
+ // Set output methods to be overridden by subclasses
+ protected def setStreamSideOutput(row: InternalRow): Unit
+ protected def setBufferedSideOutput(row: InternalRow): Unit
+
+ /**
+ * Advance to the next row on the stream side and populate the buffer with matches.
+ * @return whether there are more rows in the stream to consume.
+ */
+ private def advanceStream(): Boolean = {
+ bufferIndex = 0
if (smjScanner.findNextOuterJoinRows()) {
- joinedRow.withRight(smjScanner.getStreamedRow)
+ setStreamSideOutput(smjScanner.getStreamedRow)
if (smjScanner.getBufferedMatches.isEmpty) {
- // There are no matching left rows, so return nulls for the left row
- joinedRow.withLeft(leftNullRow)
+ // There are no matching rows in the buffer, so return the null row
+ setBufferedSideOutput(bufferedSideNullRow)
} else {
- // Find the next row from the left input that satisfied the bound condition
- if (!advanceLeftUntilBoundConditionSatisfied()) {
- joinedRow.withLeft(leftNullRow)
+ // Find the next row in the buffer that satisfies the bound condition
+ if (!advanceBufferUntilBoundConditionSatisfied()) {
+ setBufferedSideOutput(bufferedSideNullRow)
}
}
true
} else {
- // Right input has been exhausted
+ // Stream has been exhausted
false
}
}
- private def advanceLeftUntilBoundConditionSatisfied(): Boolean = {
+ /**
+ * Advance to the next row in the buffer that satisfies the bound condition.
+ * @return whether there is such a row in the current buffer.
+ */
+ private def advanceBufferUntilBoundConditionSatisfied(): Boolean = {
var foundMatch: Boolean = false
- while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) {
- foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx)))
- leftIdx += 1
+ while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) {
+ setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex))
+ foundMatch = boundCondition(joinedRow)
+ bufferIndex += 1
}
foundMatch
}
override def advanceNext(): Boolean = {
- val r = advanceLeftUntilBoundConditionSatisfied() || advanceRight()
- if (r) numRows += 1
+ val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream()
+ if (r) numOutputRows += 1
r
}