From bbd887f53cc4fa03d97932e1b570bd7180783da5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 15 Mar 2016 19:58:49 -0700 Subject: [SPARK-13918][SQL] Merge SortMergeJoin and SortMergerOuterJoin ## What changes were proposed in this pull request? This PR just move some code from SortMergeOuterJoin into SortMergeJoin. This is for support codegen for outer join. ## How was this patch tested? existing tests. Author: Davies Liu Closes #11743 from davies/gen_smjouter. --- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../spark/sql/execution/WholeStageCodegen.scala | 3 +- .../spark/sql/execution/joins/SortMergeJoin.scala | 505 ++++++++++++++++++--- .../sql/execution/joins/SortMergeOuterJoin.scala | 464 ------------------- .../scala/org/apache/spark/sql/JoinSuite.scala | 7 +- .../apache/spark/sql/execution/PlannerSuite.scala | 3 + .../spark/sql/execution/joins/InnerJoinSuite.scala | 2 +- .../spark/sql/execution/joins/OuterJoinSuite.scala | 4 +- .../sql/execution/metric/SQLMetricsSuite.scala | 10 +- 9 files changed, 467 insertions(+), 535 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 113cf9ae2f..7fc6a8267f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -120,7 +120,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if RowOrdering.isOrderable(leftKeys) => joins.SortMergeJoin( - leftKeys, rightKeys, condition, planLater(left), planLater(right)) :: Nil + leftKeys, rightKeys, Inner, condition, planLater(left), planLater(right)) :: Nil // --- Outer joins -------------------------------------------------------------------------- @@ -136,7 +136,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if RowOrdering.isOrderable(leftKeys) => - joins.SortMergeOuterJoin( + joins.SortMergeJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil // --- Cases where this strategy does not apply --------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 81676d3ebb..a54b772045 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.toCommentSafeString @@ -450,7 +451,7 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { * Inserts a InputAdapter on top of those that do not support codegen. */ private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match { - case j @ SortMergeJoin(_, _, _, left, right) => + case j @ SortMergeJoin(_, _, _, _, left, right) if j.supportCodegen => // The children of SortMergeJoin should do codegen separately. j.copy(left = InputAdapter(insertWholeStageCodegen(left)), right = InputAdapter(insertWholeStageCodegen(right))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index cffd6f6032..d0724ff6e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -23,9 +23,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, RowIterator, SparkPlan} -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} +import org.apache.spark.util.collection.BitSet /** * Performs an sort merge join of two child relations. @@ -33,6 +35,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics case class SortMergeJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + joinType: JoinType, condition: Option[Expression], left: SparkPlan, right: SparkPlan) extends BinaryNode with CodegenSupport { @@ -40,10 +43,32 @@ case class SortMergeJoin( override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def output: Seq[Attribute] = left.output ++ right.output + override def output: Seq[Attribute] = { + joinType match { + case Inner => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + (left.output ++ right.output).map(_.withNullability(true)) + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + } - override def outputPartitioning: Partitioning = - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + override def outputPartitioning: Partitioning = joinType match { + case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + // For left and right outer joins, the output is partitioned by the streamed input's join keys. + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil @@ -58,6 +83,12 @@ case class SortMergeJoin( keys.map(SortOrder(_, Ascending)) } + private def createLeftKeyGenerator(): Projection = + UnsafeProjection.create(leftKeys, left.output) + + private def createRightKeyGenerator(): Projection = + UnsafeProjection.create(rightKeys, right.output) + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") @@ -69,64 +100,122 @@ case class SortMergeJoin( (r: InternalRow) => true } } - new RowIterator { - // The projection used to extract keys from input rows of the left child. - private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output) - - // The projection used to extract keys from input rows of the right child. - private[this] val rightKeyGenerator = UnsafeProjection.create(rightKeys, right.output) - - // An ordering that can be used to compare keys from both sides. - private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) - private[this] var currentLeftRow: InternalRow = _ - private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ - private[this] var currentMatchIdx: Int = -1 - private[this] val smjScanner = new SortMergeJoinScanner( - leftKeyGenerator, - rightKeyGenerator, - keyOrdering, - RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) - ) - private[this] val joinRow = new JoinedRow - private[this] val resultProjection: (InternalRow) => InternalRow = - UnsafeProjection.create(schema) - - if (smjScanner.findNextInnerJoinRows()) { - currentRightMatches = smjScanner.getBufferedMatches - currentLeftRow = smjScanner.getStreamedRow - currentMatchIdx = 0 - } + // An ordering that can be used to compare keys from both sides. + val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) + val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output) + + joinType match { + case Inner => + new RowIterator { + // The projection used to extract keys from input rows of the left child. + private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output) + + // The projection used to extract keys from input rows of the right child. + private[this] val rightKeyGenerator = UnsafeProjection.create(rightKeys, right.output) + + // An ordering that can be used to compare keys from both sides. + private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) + private[this] var currentLeftRow: InternalRow = _ + private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ + private[this] var currentMatchIdx: Int = -1 + private[this] val smjScanner = new SortMergeJoinScanner( + leftKeyGenerator, + rightKeyGenerator, + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter) + ) + private[this] val joinRow = new JoinedRow + private[this] val resultProjection: (InternalRow) => InternalRow = + UnsafeProjection.create(schema) + + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + currentMatchIdx = 0 + } - override def advanceNext(): Boolean = { - while (currentMatchIdx >= 0) { - if (currentMatchIdx == currentRightMatches.length) { - if (smjScanner.findNextInnerJoinRows()) { - currentRightMatches = smjScanner.getBufferedMatches - currentLeftRow = smjScanner.getStreamedRow - currentMatchIdx = 0 - } else { - currentRightMatches = null - currentLeftRow = null - currentMatchIdx = -1 - return false + override def advanceNext(): Boolean = { + while (currentMatchIdx >= 0) { + if (currentMatchIdx == currentRightMatches.length) { + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + currentMatchIdx = 0 + } else { + currentRightMatches = null + currentLeftRow = null + currentMatchIdx = -1 + return false + } + } + joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) + currentMatchIdx += 1 + if (boundCondition(joinRow)) { + numOutputRows += 1 + return true + } } + false } - joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) - currentMatchIdx += 1 - if (boundCondition(joinRow)) { - numOutputRows += 1 - return true - } - } - false - } - override def getRow: InternalRow = resultProjection(joinRow) - }.toScala + override def getRow: InternalRow = resultProjection(joinRow) + }.toScala + + case LeftOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createLeftKeyGenerator(), + bufferedKeyGenerator = createRightKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(leftIter), + bufferedIter = RowIterator.fromScala(rightIter) + ) + val rightNullRow = new GenericInternalRow(right.output.length) + new LeftOuterIterator( + smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala + + case RightOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createRightKeyGenerator(), + bufferedKeyGenerator = createLeftKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(rightIter), + bufferedIter = RowIterator.fromScala(leftIter) + ) + val leftNullRow = new GenericInternalRow(left.output.length) + new RightOuterIterator( + smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala + + case FullOuter => + val leftNullRow = new GenericInternalRow(left.output.length) + val rightNullRow = new GenericInternalRow(right.output.length) + val smjScanner = new SortMergeFullOuterJoinScanner( + leftKeyGenerator = createLeftKeyGenerator(), + rightKeyGenerator = createRightKeyGenerator(), + keyOrdering, + leftIter = RowIterator.fromScala(leftIter), + rightIter = RowIterator.fromScala(rightIter), + boundCondition, + leftNullRow, + rightNullRow) + + new FullOuterIterator( + smjScanner, + resultProj, + numOutputRows).toScala + + case x => + throw new IllegalArgumentException( + s"SortMergeJoin should not take $x as the JoinType") + } + } } + override def supportCodegen: Boolean = { + joinType == Inner + } + override def upstreams(): Seq[RDD[InternalRow]] = { left.execute() :: right.execute() :: Nil } @@ -376,7 +465,7 @@ case class SortMergeJoin( } /** - * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]]. + * Helper class that is used to implement [[SortMergeJoin]]. * * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]] * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false` @@ -570,3 +659,307 @@ private[joins] class SortMergeJoinScanner( } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) } } + +/** + * An iterator for outputting rows in left outer join. + */ +private class LeftOuterIterator( + smjScanner: SortMergeJoinScanner, + rightNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) + extends OneSideOuterIterator( + smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) { + + protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) + protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) +} + +/** + * An iterator for outputting rows in right outer join. + */ +private class RightOuterIterator( + smjScanner: SortMergeJoinScanner, + leftNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) + extends OneSideOuterIterator( + smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) { + + protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) + protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) +} + +/** + * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]]. + * + * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the + * streamed side will output 0 or many rows, one for each matching row on the buffered side. + * If there are no matches, then the buffered side of the joined output will be a null row. + * + * In left outer join, the left is the streamed side and the right is the buffered side. + * In right outer join, the right is the streamed side and the left is the buffered side. + * + * @param smjScanner a scanner that streams rows and buffers any matching rows + * @param bufferedSideNullRow the default row to return when a streamed row has no matches + * @param boundCondition an additional filter condition for buffered rows + * @param resultProj how the output should be projected + * @param numOutputRows an accumulator metric for the number of rows output + */ +private abstract class OneSideOuterIterator( + smjScanner: SortMergeJoinScanner, + bufferedSideNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) extends RowIterator { + + // A row to store the joined result, reused many times + protected[this] val joinedRow: JoinedRow = new JoinedRow() + + // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row + private[this] var bufferIndex: Int = 0 + + // This iterator is initialized lazily so there should be no matches initially + assert(smjScanner.getBufferedMatches.length == 0) + + // Set output methods to be overridden by subclasses + protected def setStreamSideOutput(row: InternalRow): Unit + protected def setBufferedSideOutput(row: InternalRow): Unit + + /** + * Advance to the next row on the stream side and populate the buffer with matches. + * @return whether there are more rows in the stream to consume. + */ + private def advanceStream(): Boolean = { + bufferIndex = 0 + if (smjScanner.findNextOuterJoinRows()) { + setStreamSideOutput(smjScanner.getStreamedRow) + if (smjScanner.getBufferedMatches.isEmpty) { + // There are no matching rows in the buffer, so return the null row + setBufferedSideOutput(bufferedSideNullRow) + } else { + // Find the next row in the buffer that satisfied the bound condition + if (!advanceBufferUntilBoundConditionSatisfied()) { + setBufferedSideOutput(bufferedSideNullRow) + } + } + true + } else { + // Stream has been exhausted + false + } + } + + /** + * Advance to the next row in the buffer that satisfies the bound condition. + * @return whether there is such a row in the current buffer. + */ + private def advanceBufferUntilBoundConditionSatisfied(): Boolean = { + var foundMatch: Boolean = false + while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) { + setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex)) + foundMatch = boundCondition(joinedRow) + bufferIndex += 1 + } + foundMatch + } + + override def advanceNext(): Boolean = { + val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream() + if (r) numOutputRows += 1 + r + } + + override def getRow: InternalRow = resultProj(joinedRow) +} + +private class SortMergeFullOuterJoinScanner( + leftKeyGenerator: Projection, + rightKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + leftIter: RowIterator, + rightIter: RowIterator, + boundCondition: InternalRow => Boolean, + leftNullRow: InternalRow, + rightNullRow: InternalRow) { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var leftRow: InternalRow = _ + private[this] var leftRowKey: InternalRow = _ + private[this] var rightRow: InternalRow = _ + private[this] var rightRowKey: InternalRow = _ + + private[this] var leftIndex: Int = 0 + private[this] var rightIndex: Int = 0 + private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] var leftMatched: BitSet = new BitSet(1) + private[this] var rightMatched: BitSet = new BitSet(1) + + advancedLeft() + advancedRight() + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the left iterator and compute the new row's join key. + * @return true if the left iterator returned a row and false otherwise. + */ + private def advancedLeft(): Boolean = { + if (leftIter.advanceNext()) { + leftRow = leftIter.getRow + leftRowKey = leftKeyGenerator(leftRow) + true + } else { + leftRow = null + leftRowKey = null + false + } + } + + /** + * Advance the right iterator and compute the new row's join key. + * @return true if the right iterator returned a row and false otherwise. + */ + private def advancedRight(): Boolean = { + if (rightIter.advanceNext()) { + rightRow = rightIter.getRow + rightRowKey = rightKeyGenerator(rightRow) + true + } else { + rightRow = null + rightRowKey = null + false + } + } + + /** + * Populate the left and right buffers with rows matching the provided key. + * This consumes rows from both iterators until their keys are different from the matching key. + */ + private def findMatchingRows(matchingKey: InternalRow): Unit = { + leftMatches.clear() + rightMatches.clear() + leftIndex = 0 + rightIndex = 0 + + while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) { + leftMatches += leftRow.copy() + advancedLeft() + } + while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) { + rightMatches += rightRow.copy() + advancedRight() + } + + if (leftMatches.size <= leftMatched.capacity) { + leftMatched.clear() + } else { + leftMatched = new BitSet(leftMatches.size) + } + if (rightMatches.size <= rightMatched.capacity) { + rightMatched.clear() + } else { + rightMatched = new BitSet(rightMatches.size) + } + } + + /** + * Scan the left and right buffers for the next valid match. + * + * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers. + * If a left row has no valid matches on the right, or a right row has no valid matches on the + * left, then the row is joined with the null row and the result is considered a valid match. + * + * @return true if a valid match is found, false otherwise. + */ + private def scanNextInBuffered(): Boolean = { + while (leftIndex < leftMatches.size) { + while (rightIndex < rightMatches.size) { + joinedRow(leftMatches(leftIndex), rightMatches(rightIndex)) + if (boundCondition(joinedRow)) { + leftMatched.set(leftIndex) + rightMatched.set(rightIndex) + rightIndex += 1 + return true + } + rightIndex += 1 + } + rightIndex = 0 + if (!leftMatched.get(leftIndex)) { + // the left row has never matched any right row, join it with null row + joinedRow(leftMatches(leftIndex), rightNullRow) + leftIndex += 1 + return true + } + leftIndex += 1 + } + + while (rightIndex < rightMatches.size) { + if (!rightMatched.get(rightIndex)) { + // the right row has never matched any left row, join it with null row + joinedRow(leftNullRow, rightMatches(rightIndex)) + rightIndex += 1 + return true + } + rightIndex += 1 + } + + // There are no more valid matches in the left and right buffers + false + } + + // --- Public methods -------------------------------------------------------------------------- + + def getJoinedRow(): JoinedRow = joinedRow + + def advanceNext(): Boolean = { + // If we already buffered some matching rows, use them directly + if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) { + if (scanNextInBuffered()) { + return true + } + } + + if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) { + joinedRow(leftRow.copy(), rightNullRow) + advancedLeft() + true + } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) { + joinedRow(leftNullRow, rightRow.copy()) + advancedRight() + true + } else if (leftRow != null && rightRow != null) { + // Both rows are present and neither have null values, + // so we populate the buffers with rows matching the next key + val comp = keyOrdering.compare(leftRowKey, rightRowKey) + if (comp <= 0) { + findMatchingRows(leftRowKey.copy()) + } else { + findMatchingRows(rightRowKey.copy()) + } + scanNextInBuffered() + true + } else { + // Both iterators have been consumed + false + } + } +} + +private class FullOuterIterator( + smjScanner: SortMergeFullOuterJoinScanner, + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric +) extends RowIterator { + private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow() + + override def advanceNext(): Boolean = { + val r = smjScanner.advanceNext() + if (r) numRows += 1 + r + } + + override def getRow: InternalRow = resultProj(joinedRow) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala deleted file mode 100644 index 40a6c93b59..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ /dev/null @@ -1,464 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} -import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} -import org.apache.spark.util.collection.BitSet - -/** - * Performs an sort merge outer join of two child relations. - */ -case class SortMergeOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override private[sql] lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def output: Seq[Attribute] = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - (left.output ++ right.output).map(_.withNullability(true)) - case x => - throw new IllegalArgumentException( - s"${getClass.getSimpleName} should not take $x as the JoinType") - } - } - - override def outputPartitioning: Partitioning = joinType match { - // For left and right outer joins, the output is partitioned by the streamed input's join keys. - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => - throw new IllegalArgumentException( - s"${getClass.getSimpleName} should not take $x as the JoinType") - } - - override def outputOrdering: Seq[SortOrder] = joinType match { - // For left and right outer joins, the output is ordered by the streamed input's join keys. - case LeftOuter => requiredOrders(leftKeys) - case RightOuter => requiredOrders(rightKeys) - // there are null rows in both streams, so there is no order - case FullOuter => Nil - case x => throw new IllegalArgumentException( - s"SortMergeOuterJoin should not take $x as the JoinType") - } - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - - private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { - // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. - keys.map(SortOrder(_, Ascending)) - } - - private def createLeftKeyGenerator(): Projection = - UnsafeProjection.create(leftKeys, left.output) - - private def createRightKeyGenerator(): Projection = - UnsafeProjection.create(rightKeys, right.output) - - override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // An ordering that can be used to compare keys from both sides. - val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) - val boundCondition: (InternalRow) => Boolean = { - condition.map { cond => - newPredicate(cond, left.output ++ right.output) - }.getOrElse { - (r: InternalRow) => true - } - } - val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output) - - joinType match { - case LeftOuter => - val smjScanner = new SortMergeJoinScanner( - streamedKeyGenerator = createLeftKeyGenerator(), - bufferedKeyGenerator = createRightKeyGenerator(), - keyOrdering, - streamedIter = RowIterator.fromScala(leftIter), - bufferedIter = RowIterator.fromScala(rightIter) - ) - val rightNullRow = new GenericInternalRow(right.output.length) - new LeftOuterIterator( - smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala - - case RightOuter => - val smjScanner = new SortMergeJoinScanner( - streamedKeyGenerator = createRightKeyGenerator(), - bufferedKeyGenerator = createLeftKeyGenerator(), - keyOrdering, - streamedIter = RowIterator.fromScala(rightIter), - bufferedIter = RowIterator.fromScala(leftIter) - ) - val leftNullRow = new GenericInternalRow(left.output.length) - new RightOuterIterator( - smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala - - case FullOuter => - val leftNullRow = new GenericInternalRow(left.output.length) - val rightNullRow = new GenericInternalRow(right.output.length) - val smjScanner = new SortMergeFullOuterJoinScanner( - leftKeyGenerator = createLeftKeyGenerator(), - rightKeyGenerator = createRightKeyGenerator(), - keyOrdering, - leftIter = RowIterator.fromScala(leftIter), - rightIter = RowIterator.fromScala(rightIter), - boundCondition, - leftNullRow, - rightNullRow) - - new FullOuterIterator( - smjScanner, - resultProj, - numOutputRows).toScala - - case x => - throw new IllegalArgumentException( - s"SortMergeOuterJoin should not take $x as the JoinType") - } - } - } -} - -/** - * An iterator for outputting rows in left outer join. - */ -private class LeftOuterIterator( - smjScanner: SortMergeJoinScanner, - rightNullRow: InternalRow, - boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow, - numOutputRows: LongSQLMetric) - extends OneSideOuterIterator( - smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) { - - protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) - protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) -} - -/** - * An iterator for outputting rows in right outer join. - */ -private class RightOuterIterator( - smjScanner: SortMergeJoinScanner, - leftNullRow: InternalRow, - boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow, - numOutputRows: LongSQLMetric) - extends OneSideOuterIterator( - smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) { - - protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) - protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) -} - -/** - * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]]. - * - * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the - * streamed side will output 0 or many rows, one for each matching row on the buffered side. - * If there are no matches, then the buffered side of the joined output will be a null row. - * - * In left outer join, the left is the streamed side and the right is the buffered side. - * In right outer join, the right is the streamed side and the left is the buffered side. - * - * @param smjScanner a scanner that streams rows and buffers any matching rows - * @param bufferedSideNullRow the default row to return when a streamed row has no matches - * @param boundCondition an additional filter condition for buffered rows - * @param resultProj how the output should be projected - * @param numOutputRows an accumulator metric for the number of rows output - */ -private abstract class OneSideOuterIterator( - smjScanner: SortMergeJoinScanner, - bufferedSideNullRow: InternalRow, - boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow, - numOutputRows: LongSQLMetric) extends RowIterator { - - // A row to store the joined result, reused many times - protected[this] val joinedRow: JoinedRow = new JoinedRow() - - // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row - private[this] var bufferIndex: Int = 0 - - // This iterator is initialized lazily so there should be no matches initially - assert(smjScanner.getBufferedMatches.length == 0) - - // Set output methods to be overridden by subclasses - protected def setStreamSideOutput(row: InternalRow): Unit - protected def setBufferedSideOutput(row: InternalRow): Unit - - /** - * Advance to the next row on the stream side and populate the buffer with matches. - * @return whether there are more rows in the stream to consume. - */ - private def advanceStream(): Boolean = { - bufferIndex = 0 - if (smjScanner.findNextOuterJoinRows()) { - setStreamSideOutput(smjScanner.getStreamedRow) - if (smjScanner.getBufferedMatches.isEmpty) { - // There are no matching rows in the buffer, so return the null row - setBufferedSideOutput(bufferedSideNullRow) - } else { - // Find the next row in the buffer that satisfied the bound condition - if (!advanceBufferUntilBoundConditionSatisfied()) { - setBufferedSideOutput(bufferedSideNullRow) - } - } - true - } else { - // Stream has been exhausted - false - } - } - - /** - * Advance to the next row in the buffer that satisfies the bound condition. - * @return whether there is such a row in the current buffer. - */ - private def advanceBufferUntilBoundConditionSatisfied(): Boolean = { - var foundMatch: Boolean = false - while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) { - setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex)) - foundMatch = boundCondition(joinedRow) - bufferIndex += 1 - } - foundMatch - } - - override def advanceNext(): Boolean = { - val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream() - if (r) numOutputRows += 1 - r - } - - override def getRow: InternalRow = resultProj(joinedRow) -} - -private class SortMergeFullOuterJoinScanner( - leftKeyGenerator: Projection, - rightKeyGenerator: Projection, - keyOrdering: Ordering[InternalRow], - leftIter: RowIterator, - rightIter: RowIterator, - boundCondition: InternalRow => Boolean, - leftNullRow: InternalRow, - rightNullRow: InternalRow) { - private[this] val joinedRow: JoinedRow = new JoinedRow() - private[this] var leftRow: InternalRow = _ - private[this] var leftRowKey: InternalRow = _ - private[this] var rightRow: InternalRow = _ - private[this] var rightRowKey: InternalRow = _ - - private[this] var leftIndex: Int = 0 - private[this] var rightIndex: Int = 0 - private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] - private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] - private[this] var leftMatched: BitSet = new BitSet(1) - private[this] var rightMatched: BitSet = new BitSet(1) - - advancedLeft() - advancedRight() - - // --- Private methods -------------------------------------------------------------------------- - - /** - * Advance the left iterator and compute the new row's join key. - * @return true if the left iterator returned a row and false otherwise. - */ - private def advancedLeft(): Boolean = { - if (leftIter.advanceNext()) { - leftRow = leftIter.getRow - leftRowKey = leftKeyGenerator(leftRow) - true - } else { - leftRow = null - leftRowKey = null - false - } - } - - /** - * Advance the right iterator and compute the new row's join key. - * @return true if the right iterator returned a row and false otherwise. - */ - private def advancedRight(): Boolean = { - if (rightIter.advanceNext()) { - rightRow = rightIter.getRow - rightRowKey = rightKeyGenerator(rightRow) - true - } else { - rightRow = null - rightRowKey = null - false - } - } - - /** - * Populate the left and right buffers with rows matching the provided key. - * This consumes rows from both iterators until their keys are different from the matching key. - */ - private def findMatchingRows(matchingKey: InternalRow): Unit = { - leftMatches.clear() - rightMatches.clear() - leftIndex = 0 - rightIndex = 0 - - while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) { - leftMatches += leftRow.copy() - advancedLeft() - } - while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) { - rightMatches += rightRow.copy() - advancedRight() - } - - if (leftMatches.size <= leftMatched.capacity) { - leftMatched.clear() - } else { - leftMatched = new BitSet(leftMatches.size) - } - if (rightMatches.size <= rightMatched.capacity) { - rightMatched.clear() - } else { - rightMatched = new BitSet(rightMatches.size) - } - } - - /** - * Scan the left and right buffers for the next valid match. - * - * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers. - * If a left row has no valid matches on the right, or a right row has no valid matches on the - * left, then the row is joined with the null row and the result is considered a valid match. - * - * @return true if a valid match is found, false otherwise. - */ - private def scanNextInBuffered(): Boolean = { - while (leftIndex < leftMatches.size) { - while (rightIndex < rightMatches.size) { - joinedRow(leftMatches(leftIndex), rightMatches(rightIndex)) - if (boundCondition(joinedRow)) { - leftMatched.set(leftIndex) - rightMatched.set(rightIndex) - rightIndex += 1 - return true - } - rightIndex += 1 - } - rightIndex = 0 - if (!leftMatched.get(leftIndex)) { - // the left row has never matched any right row, join it with null row - joinedRow(leftMatches(leftIndex), rightNullRow) - leftIndex += 1 - return true - } - leftIndex += 1 - } - - while (rightIndex < rightMatches.size) { - if (!rightMatched.get(rightIndex)) { - // the right row has never matched any left row, join it with null row - joinedRow(leftNullRow, rightMatches(rightIndex)) - rightIndex += 1 - return true - } - rightIndex += 1 - } - - // There are no more valid matches in the left and right buffers - false - } - - // --- Public methods -------------------------------------------------------------------------- - - def getJoinedRow(): JoinedRow = joinedRow - - def advanceNext(): Boolean = { - // If we already buffered some matching rows, use them directly - if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) { - if (scanNextInBuffered()) { - return true - } - } - - if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) { - joinedRow(leftRow.copy(), rightNullRow) - advancedLeft() - true - } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) { - joinedRow(leftNullRow, rightRow.copy()) - advancedRight() - true - } else if (leftRow != null && rightRow != null) { - // Both rows are present and neither have null values, - // so we populate the buffers with rows matching the next key - val comp = keyOrdering.compare(leftRowKey, rightRowKey) - if (comp <= 0) { - findMatchingRows(leftRowKey.copy()) - } else { - findMatchingRows(rightRowKey.copy()) - } - scanNextInBuffered() - true - } else { - // Both iterators have been consumed - false - } - } -} - -private class FullOuterIterator( - smjScanner: SortMergeFullOuterJoinScanner, - resultProj: InternalRow => InternalRow, - numRows: LongSQLMetric - ) extends RowIterator { - private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow() - - override def advanceNext(): Boolean = { - val r = smjScanner.advanceNext() - if (r) numRows += 1 - r - } - - override def getRow: InternalRow = resultProj(joinedRow) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 50647c2840..580e8d815a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -51,7 +51,6 @@ class JoinSuite extends QueryTest with SharedSQLContext { case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j case j: SortMergeJoin => j - case j: SortMergeOuterJoin => j } assert(operators.size === 1) @@ -83,13 +82,13 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[SortMergeOuterJoin]), + classOf[SortMergeJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", - classOf[SortMergeOuterJoin]), + classOf[SortMergeJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 88fbcda296..9cd50abda6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} +import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation @@ -487,6 +488,7 @@ class PlannerSuite extends SharedSQLContext { val inputPlan = SortMergeJoin( Literal(1) :: Nil, Literal(1) :: Nil, + Inner, None, shuffle, shuffle) @@ -503,6 +505,7 @@ class PlannerSuite extends SharedSQLContext { val inputPlan2 = SortMergeJoin( Literal(1) :: Nil, Literal(1) :: Nil, + Inner, None, ShuffleExchange(finalPartitioning, inputPlan), ShuffleExchange(finalPartitioning, inputPlan)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index eeb44404e9..814e25d10e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -108,7 +108,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftPlan: SparkPlan, rightPlan: SparkPlan) = { val sortMergeJoin = - joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan) + joins.SortMergeJoin(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan) EnsureRequirements(sqlContext.sessionState.conf).apply(sortMergeJoin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 4525486430..547d06236b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -94,12 +94,12 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using SortMergeOuterJoin") { + test(s"$testName using SortMergeJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(sqlContext.sessionState.conf).apply( - SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + SortMergeJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index f754acb761..d7bd215af5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -179,18 +179,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } - test("SortMergeOuterJoin metrics") { - // Because SortMergeOuterJoin may skip different rows if the number of partitions is different, + test("SortMergeJoin(outer) metrics") { + // Because SortMergeJoin may skip different rows if the number of partitions is different, // this test should use the deterministic number of partitions. val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.registerTempTable("testDataForJoin") withTempTable("testDataForJoin") { // Assume the execution plan is - // ... -> SortMergeOuterJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) val df = sqlContext.sql( "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( - 0L -> ("SortMergeOuterJoin", Map( + 0L -> ("SortMergeJoin", Map( // It's 4 because we only read 3 rows in the first partition and 1 row in the second one "number of output rows" -> 8L))) ) @@ -198,7 +198,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df2 = sqlContext.sql( "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df2, 1, Map( - 0L -> ("SortMergeOuterJoin", Map( + 0L -> ("SortMergeJoin", Map( // It's 4 because we only read 3 rows in the first partition and 1 row in the second one "number of output rows" -> 8L))) ) -- cgit v1.2.3