From 45de518742446ddfbd4816c9d0f8501139f9bc2d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 9 Sep 2015 16:02:27 -0700 Subject: [SPARK-9730] [SQL] Add Full Outer Join support for SortMergeJoin This PR is based on #8383 , thanks to viirya JIRA: https://issues.apache.org/jira/browse/SPARK-9730 This patch adds the Full Outer Join support for SortMergeJoin. A new class SortMergeFullJoinScanner is added to scan rows from left and right iterators. FullOuterIterator is simply a wrapper of type RowIterator to consume joined rows from SortMergeFullJoinScanner. Closes #8383 Author: Liang-Chi Hsieh Author: Davies Liu Closes #8579 from davies/smj_fullouter. --- .../org/apache/spark/util/collection/BitSet.scala | 11 + .../spark/sql/execution/SparkStrategies.scala | 9 +- .../sql/execution/joins/SortMergeOuterJoin.scala | 227 ++++++++++++++++++++- .../scala/org/apache/spark/sql/JoinSuite.scala | 2 +- .../spark/sql/execution/joins/OuterJoinSuite.scala | 44 ++-- 5 files changed, 259 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 9c15b1188d..7ab67fc3a2 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -32,6 +32,17 @@ class BitSet(numBits: Int) extends Serializable { */ def capacity: Int = numWords * 64 + /** + * Clear all set bits. + */ + def clear(): Unit = { + var i = 0 + while (i < numWords) { + words(i) = 0L + i += 1 + } + } + /** * Set all the bits up to a given index */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2170bc73a0..4572d5efc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -132,15 +132,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { joins.BroadcastHashOuterJoin( leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil - case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => joins.SortMergeOuterJoin( - leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil - - case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => - joins.SortMergeOuterJoin( - leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => joins.ShuffledHashOuterJoin( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index dea9e5e580..ab20ee573a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -17,20 +17,21 @@ package org.apache.spark.sql.execution.joins +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.util.collection.BitSet /** * :: DeveloperApi :: * Performs an sort merge outer join of two child relations. - * - * Note: this does not support full outer join yet; see SPARK-9730 for progress on this. */ @DeveloperApi case class SortMergeOuterJoin( @@ -52,6 +53,8 @@ case class SortMergeOuterJoin( left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + (left.output ++ right.output).map(_.withNullability(true)) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") @@ -62,6 +65,7 @@ case class SortMergeOuterJoin( // For left and right outer joins, the output is partitioned by the streamed input's join keys. case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") @@ -71,6 +75,8 @@ case class SortMergeOuterJoin( // For left and right outer joins, the output is ordered by the streamed input's join keys. case LeftOuter => requiredOrders(leftKeys) case RightOuter => requiredOrders(rightKeys) + // there are null rows in both streams, so there is no order + case FullOuter => Nil case x => throw new IllegalArgumentException( s"SortMergeOuterJoin should not take $x as the JoinType") } @@ -165,6 +171,26 @@ case class SortMergeOuterJoin( new RightOuterIterator( smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala + case FullOuter => + val leftNullRow = new GenericInternalRow(left.output.length) + val rightNullRow = new GenericInternalRow(right.output.length) + val smjScanner = new SortMergeFullOuterJoinScanner( + leftKeyGenerator = createLeftKeyGenerator(), + rightKeyGenerator = createRightKeyGenerator(), + keyOrdering, + leftIter = RowIterator.fromScala(leftIter), + numLeftRows, + rightIter = RowIterator.fromScala(rightIter), + numRightRows, + boundCondition, + leftNullRow, + rightNullRow) + + new FullOuterIterator( + smjScanner, + resultProj, + numOutputRows).toScala + case x => throw new IllegalArgumentException( s"SortMergeOuterJoin should not take $x as the JoinType") @@ -271,3 +297,196 @@ private class RightOuterIterator( override def getRow: InternalRow = resultProj(joinedRow) } + +private class SortMergeFullOuterJoinScanner( + leftKeyGenerator: Projection, + rightKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + leftIter: RowIterator, + numLeftRows: LongSQLMetric, + rightIter: RowIterator, + numRightRows: LongSQLMetric, + boundCondition: InternalRow => Boolean, + leftNullRow: InternalRow, + rightNullRow: InternalRow) { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var leftRow: InternalRow = _ + private[this] var leftRowKey: InternalRow = _ + private[this] var rightRow: InternalRow = _ + private[this] var rightRowKey: InternalRow = _ + + private[this] var leftIndex: Int = 0 + private[this] var rightIndex: Int = 0 + private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] var leftMatched: BitSet = new BitSet(1) + private[this] var rightMatched: BitSet = new BitSet(1) + + advancedLeft() + advancedRight() + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the left iterator and compute the new row's join key. + * @return true if the left iterator returned a row and false otherwise. + */ + private def advancedLeft(): Boolean = { + if (leftIter.advanceNext()) { + leftRow = leftIter.getRow + leftRowKey = leftKeyGenerator(leftRow) + numLeftRows += 1 + true + } else { + leftRow = null + leftRowKey = null + false + } + } + + /** + * Advance the right iterator and compute the new row's join key. + * @return true if the right iterator returned a row and false otherwise. + */ + private def advancedRight(): Boolean = { + if (rightIter.advanceNext()) { + rightRow = rightIter.getRow + rightRowKey = rightKeyGenerator(rightRow) + numRightRows += 1 + true + } else { + rightRow = null + rightRowKey = null + false + } + } + + /** + * Populate the left and right buffers with rows matching the provided key. + * This consumes rows from both iterators until their keys are different from the matching key. + */ + private def findMatchingRows(matchingKey: InternalRow): Unit = { + leftMatches.clear() + rightMatches.clear() + leftIndex = 0 + rightIndex = 0 + + while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) { + leftMatches += leftRow.copy() + advancedLeft() + } + while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) { + rightMatches += rightRow.copy() + advancedRight() + } + + if (leftMatches.size <= leftMatched.capacity) { + leftMatched.clear() + } else { + leftMatched = new BitSet(leftMatches.size) + } + if (rightMatches.size <= rightMatched.capacity) { + rightMatched.clear() + } else { + rightMatched = new BitSet(rightMatches.size) + } + } + + /** + * Scan the left and right buffers for the next valid match. + * + * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers. + * If a left row has no valid matches on the right, or a right row has no valid matches on the + * left, then the row is joined with the null row and the result is considered a valid match. + * + * @return true if a valid match is found, false otherwise. + */ + private def scanNextInBuffered(): Boolean = { + while (leftIndex < leftMatches.size) { + while (rightIndex < rightMatches.size) { + joinedRow(leftMatches(leftIndex), rightMatches(rightIndex)) + if (boundCondition(joinedRow)) { + leftMatched.set(leftIndex) + rightMatched.set(rightIndex) + rightIndex += 1 + return true + } + rightIndex += 1 + } + rightIndex = 0 + if (!leftMatched.get(leftIndex)) { + // the left row has never matched any right row, join it with null row + joinedRow(leftMatches(leftIndex), rightNullRow) + leftIndex += 1 + return true + } + leftIndex += 1 + } + + while (rightIndex < rightMatches.size) { + if (!rightMatched.get(rightIndex)) { + // the right row has never matched any left row, join it with null row + joinedRow(leftNullRow, rightMatches(rightIndex)) + rightIndex += 1 + return true + } + rightIndex += 1 + } + + // There are no more valid matches in the left and right buffers + false + } + + // --- Public methods -------------------------------------------------------------------------- + + def getJoinedRow(): JoinedRow = joinedRow + + def advanceNext(): Boolean = { + // If we already buffered some matching rows, use them directly + if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) { + if (scanNextInBuffered()) { + return true + } + } + + if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) { + joinedRow(leftRow.copy(), rightNullRow) + advancedLeft() + true + } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) { + joinedRow(leftNullRow, rightRow.copy()) + advancedRight() + true + } else if (leftRow != null && rightRow != null) { + // Both rows are present and neither have null values, + // so we populate the buffers with rows matching the next key + val comp = keyOrdering.compare(leftRowKey, rightRowKey) + if (comp <= 0) { + findMatchingRows(leftRowKey.copy()) + } else { + findMatchingRows(rightRowKey.copy()) + } + scanNextInBuffered() + true + } else { + // Both iterators have been consumed + false + } + } +} + +private class FullOuterIterator( + smjScanner: SortMergeFullOuterJoinScanner, + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow() + + override def advanceNext(): Boolean = { + val r = smjScanner.advanceNext() + if (r) numRows += 1 + r + } + + override def getRow: InternalRow = resultProj(joinedRow) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index b05435bad5..7a027e1308 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -83,7 +83,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index c2e0bdac17..09e0237a7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -76,37 +76,37 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { test(s"$testName using ShuffledHashOuterJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( - ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } if (joinType != FullOuter) { test(s"$testName using BroadcastHashOuterJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } + } - test(s"$testName using SortMergeOuterJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( - SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = false) - } + test(s"$testName using SortMergeOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } } } -- cgit v1.2.3