-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/BitSet.scala                      11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala            9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala  227
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala                             2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala       44
5 files changed, 259 insertions(+), 34 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
index 9c15b1188d..7ab67fc3a2 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -33,6 +33,17 @@ class BitSet(numBits: Int) extends Serializable {
def capacity: Int = numWords * 64
/**
+ * Clear all set bits.
+ */
+ def clear(): Unit = {
+ var i = 0
+ while (i < numWords) {
+ words(i) = 0L
+ i += 1
+ }
+ }
+
+ /**
* Set all the bits up to a given index
*/
*/
def setUntil(bitIndex: Int) {
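
The new clear() above exists so the full outer join scanner added below can recycle its match-tracking bitsets across key groups instead of allocating a fresh BitSet per group. A minimal sketch of that grow-only reuse pattern, mirroring the logic in findMatchingRows (resetFor is a hypothetical helper, not part of the patch):

    import org.apache.spark.util.collection.BitSet

    // Zero out the words when the group fits in the current capacity;
    // reallocate only when it does not. Note that capacity is rounded up
    // to a multiple of 64 bits, so it may exceed the size first requested.
    var matched = new BitSet(1)
    def resetFor(groupSize: Int): Unit = {
      if (groupSize <= matched.capacity) matched.clear()
      else matched = new BitSet(groupSize)
    }
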
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 2170bc73a0..4572d5efc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -132,15 +132,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
joins.BroadcastHashOuterJoin(
leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
- case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
joins.SortMergeOuterJoin(
- leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
-
- case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
- if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
- joins.SortMergeOuterJoin(
- leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
+ leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
joins.ShuffledHashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
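
The two separate LeftOuter and RightOuter cases collapse into a single case on joinType; since SortMergeOuterJoin now also accepts FullOuter, one guard covers all three outer join types. A condensed paraphrase of the resulting dispatch (a sketch that ignores the broadcast cases above it, not the planner's exact code):

    import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}

    // One sort-merge case now handles LeftOuter, RightOuter, and FullOuter
    // alike; ShuffledHashOuterJoin remains the fallback when sort-merge is
    // disabled or the join keys are not orderable.
    def chooseOuterJoin(joinType: JoinType, smjEnabled: Boolean, orderable: Boolean): String =
      joinType match {
        case LeftOuter | RightOuter | FullOuter if smjEnabled && orderable =>
          "SortMergeOuterJoin"
        case _ =>
          "ShuffledHashOuterJoin"
      }
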
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
index dea9e5e580..ab20ee573a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
@@ -17,20 +17,21 @@
package org.apache.spark.sql.execution.joins
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
+import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
+import org.apache.spark.util.collection.BitSet
/**
* :: DeveloperApi ::
* Performs a sort merge outer join of two child relations.
- *
- * Note: this does not support full outer join yet; see SPARK-9730 for progress on this.
*/
@DeveloperApi
case class SortMergeOuterJoin(
@@ -52,6 +53,8 @@ case class SortMergeOuterJoin(
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
+ case FullOuter =>
+ (left.output ++ right.output).map(_.withNullability(true))
case x =>
throw new IllegalArgumentException(
s"${getClass.getSimpleName} should not take $x as the JoinType")
@@ -62,6 +65,7 @@ case class SortMergeOuterJoin(
// For left and right outer joins, the output is partitioned by the streamed input's join keys.
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
+ case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case x =>
throw new IllegalArgumentException(
s"${getClass.getSimpleName} should not take $x as the JoinType")
@@ -71,6 +75,8 @@ case class SortMergeOuterJoin(
// For left and right outer joins, the output is ordered by the streamed input's join keys.
case LeftOuter => requiredOrders(leftKeys)
case RightOuter => requiredOrders(rightKeys)
+ // null-padded rows can appear in both streams, so no output ordering is guaranteed
+ case FullOuter => Nil
case x => throw new IllegalArgumentException(
s"SortMergeOuterJoin should not take $x as the JoinType")
}
@@ -165,6 +171,26 @@ case class SortMergeOuterJoin(
new RightOuterIterator(
smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala
+ case FullOuter =>
+ val leftNullRow = new GenericInternalRow(left.output.length)
+ val rightNullRow = new GenericInternalRow(right.output.length)
+ val smjScanner = new SortMergeFullOuterJoinScanner(
+ leftKeyGenerator = createLeftKeyGenerator(),
+ rightKeyGenerator = createRightKeyGenerator(),
+ keyOrdering,
+ leftIter = RowIterator.fromScala(leftIter),
+ numLeftRows,
+ rightIter = RowIterator.fromScala(rightIter),
+ numRightRows,
+ boundCondition,
+ leftNullRow,
+ rightNullRow)
+
+ new FullOuterIterator(
+ smjScanner,
+ resultProj,
+ numOutputRows).toScala
+
case x =>
throw new IllegalArgumentException(
s"SortMergeOuterJoin should not take $x as the JoinType")
@@ -271,3 +297,196 @@ private class RightOuterIterator(
override def getRow: InternalRow = resultProj(joinedRow)
}
+
+private class SortMergeFullOuterJoinScanner(
+ leftKeyGenerator: Projection,
+ rightKeyGenerator: Projection,
+ keyOrdering: Ordering[InternalRow],
+ leftIter: RowIterator,
+ numLeftRows: LongSQLMetric,
+ rightIter: RowIterator,
+ numRightRows: LongSQLMetric,
+ boundCondition: InternalRow => Boolean,
+ leftNullRow: InternalRow,
+ rightNullRow: InternalRow) {
+ private[this] val joinedRow: JoinedRow = new JoinedRow()
+ private[this] var leftRow: InternalRow = _
+ private[this] var leftRowKey: InternalRow = _
+ private[this] var rightRow: InternalRow = _
+ private[this] var rightRowKey: InternalRow = _
+
+ private[this] var leftIndex: Int = 0
+ private[this] var rightIndex: Int = 0
+ private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+ private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+ private[this] var leftMatched: BitSet = new BitSet(1)
+ private[this] var rightMatched: BitSet = new BitSet(1)
+
+ advancedLeft()
+ advancedRight()
+
+ // --- Private methods --------------------------------------------------------------------------
+
+ /**
+ * Advance the left iterator and compute the new row's join key.
+ * @return true if the left iterator returned a row and false otherwise.
+ */
+ private def advancedLeft(): Boolean = {
+ if (leftIter.advanceNext()) {
+ leftRow = leftIter.getRow
+ leftRowKey = leftKeyGenerator(leftRow)
+ numLeftRows += 1
+ true
+ } else {
+ leftRow = null
+ leftRowKey = null
+ false
+ }
+ }
+
+ /**
+ * Advance the right iterator and compute the new row's join key.
+ * @return true if the right iterator returned a row and false otherwise.
+ */
+ private def advancedRight(): Boolean = {
+ if (rightIter.advanceNext()) {
+ rightRow = rightIter.getRow
+ rightRowKey = rightKeyGenerator(rightRow)
+ numRightRows += 1
+ true
+ } else {
+ rightRow = null
+ rightRowKey = null
+ false
+ }
+ }
+
+ /**
+ * Populate the left and right buffers with rows matching the provided key.
+ * This consumes rows from both iterators until their keys are different from the matching key.
+ */
+ private def findMatchingRows(matchingKey: InternalRow): Unit = {
+ leftMatches.clear()
+ rightMatches.clear()
+ leftIndex = 0
+ rightIndex = 0
+
+ while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) {
+ leftMatches += leftRow.copy()
+ advancedLeft()
+ }
+ while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) {
+ rightMatches += rightRow.copy()
+ advancedRight()
+ }
+
+ if (leftMatches.size <= leftMatched.capacity) {
+ leftMatched.clear()
+ } else {
+ leftMatched = new BitSet(leftMatches.size)
+ }
+ if (rightMatches.size <= rightMatched.capacity) {
+ rightMatched.clear()
+ } else {
+ rightMatched = new BitSet(rightMatches.size)
+ }
+ }
+
+ /**
+ * Scan the left and right buffers for the next valid match.
+ *
+ * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers.
+ * If a left row has no valid matches on the right, or a right row has no valid matches on the
+ * left, then the row is joined with the null row and the result is considered a valid match.
+ *
+ * @return true if a valid match is found, false otherwise.
+ */
+ private def scanNextInBuffered(): Boolean = {
+ while (leftIndex < leftMatches.size) {
+ while (rightIndex < rightMatches.size) {
+ joinedRow(leftMatches(leftIndex), rightMatches(rightIndex))
+ if (boundCondition(joinedRow)) {
+ leftMatched.set(leftIndex)
+ rightMatched.set(rightIndex)
+ rightIndex += 1
+ return true
+ }
+ rightIndex += 1
+ }
+ rightIndex = 0
+ if (!leftMatched.get(leftIndex)) {
+ // the left row has never matched any right row; join it with the null row
+ joinedRow(leftMatches(leftIndex), rightNullRow)
+ leftIndex += 1
+ return true
+ }
+ leftIndex += 1
+ }
+
+ while (rightIndex < rightMatches.size) {
+ if (!rightMatched.get(rightIndex)) {
+ // the right row has never matched any left row; join it with the null row
+ joinedRow(leftNullRow, rightMatches(rightIndex))
+ rightIndex += 1
+ return true
+ }
+ rightIndex += 1
+ }
+
+ // There are no more valid matches in the left and right buffers
+ false
+ }
+
+ // --- Public methods --------------------------------------------------------------------------
+
+ def getJoinedRow(): JoinedRow = joinedRow
+
+ def advanceNext(): Boolean = {
+ // If we already buffered some matching rows, use them directly
+ if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) {
+ if (scanNextInBuffered()) {
+ return true
+ }
+ }
+
+ if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) {
+ joinedRow(leftRow.copy(), rightNullRow)
+ advancedLeft()
+ true
+ } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) {
+ joinedRow(leftNullRow, rightRow.copy())
+ advancedRight()
+ true
+ } else if (leftRow != null && rightRow != null) {
+ // Both rows are present and neither join key contains nulls,
+ // so we populate the buffers with rows matching the next key
+ val comp = keyOrdering.compare(leftRowKey, rightRowKey)
+ if (comp <= 0) {
+ findMatchingRows(leftRowKey.copy())
+ } else {
+ findMatchingRows(rightRowKey.copy())
+ }
+ scanNextInBuffered()
+ true
+ } else {
+ // Both iterators have been consumed
+ false
+ }
+ }
+}
+
+private class FullOuterIterator(
+ smjScanner: SortMergeFullOuterJoinScanner,
+ resultProj: InternalRow => InternalRow,
+ numRows: LongSQLMetric
+ ) extends RowIterator {
+ private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
+
+ override def advanceNext(): Boolean = {
+ val r = smjScanner.advanceNext()
+ if (r) numRows += 1
+ r
+ }
+
+ override def getRow: InternalRow = resultProj(joinedRow)
+}
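
For intuition, the scanner's behavior can be modeled on plain sorted sequences: buffer every row sharing the smallest key present on either side (as findMatchingRows does), emit the condition-passing pairs while tracking which rows matched (as scanNextInBuffered does with its two BitSets), then null-pad whatever never matched. The sketch below is an illustration of the algorithm, not Spark's API; Option stands in for null padding, and the null-key fast path in advanceNext is omitted since ordinary Scala keys cannot be null here:

    import scala.collection.mutable.ArrayBuffer

    def fullOuterMerge[K, L, R](
        left: Seq[(K, L)],
        right: Seq[(K, R)],
        condition: (L, R) => Boolean)(implicit ord: Ordering[K]): Seq[(Option[L], Option[R])] = {
      val out = ArrayBuffer.empty[(Option[L], Option[R])]
      var i = 0
      var j = 0
      while (i < left.length || j < right.length) {
        // The smallest key still present on either side ...
        val key =
          if (j >= right.length) left(i)._1
          else if (i >= left.length) right(j)._1
          else ord.min(left(i)._1, right(j)._1)
        // ... and every row carrying it, buffered from both inputs.
        val ls = ArrayBuffer.empty[L]
        while (i < left.length && ord.equiv(left(i)._1, key)) { ls += left(i)._2; i += 1 }
        val rs = ArrayBuffer.empty[R]
        while (j < right.length && ord.equiv(right(j)._1, key)) { rs += right(j)._2; j += 1 }
        // Cross-product scan with match tracking; the Boolean arrays play
        // the role of leftMatched/rightMatched.
        val lMatched = Array.fill(ls.length)(false)
        val rMatched = Array.fill(rs.length)(false)
        for (li <- ls.indices; ri <- rs.indices if condition(ls(li), rs(ri))) {
          out += ((Some(ls(li)), Some(rs(ri))))
          lMatched(li) = true
          rMatched(ri) = true
        }
        // Rows that never matched are padded with None (the null row).
        for (li <- ls.indices if !lMatched(li)) out += ((Some(ls(li)), None))
        for (ri <- rs.indices if !rMatched(ri)) out += ((None, Some(rs(ri))))
      }
      out.toSeq
    }

Running it on a tiny input exercises all three output shapes (the exact row order differs slightly from the scanner, which interleaves unmatched rows into the per-key scan):

    fullOuterMerge(Seq(1 -> "a", 2 -> "b"), Seq(2 -> "x", 3 -> "y"),
      (_: String, _: String) => true).foreach(println)
    // (Some(a),None)
    // (Some(b),Some(x))
    // (None,Some(y))
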
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index b05435bad5..7a027e1308 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -83,7 +83,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
classOf[SortMergeOuterJoin]),
("SELECT * FROM testData full outer join testData2 ON key = a",
- classOf[ShuffledHashOuterJoin]),
+ classOf[SortMergeOuterJoin]),
("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index c2e0bdac17..09e0237a7c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -76,37 +76,37 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
test(s"$testName using ShuffledHashOuterJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
- checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- EnsureRequirements(sqlContext).apply(
- ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
- expectedAnswer.map(Row.fromTuple),
- sortAnswers = true)
- }
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ EnsureRequirements(sqlContext).apply(
+ ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
}
}
if (joinType != FullOuter) {
test(s"$testName using BroadcastHashOuterJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
- checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
- expectedAnswer.map(Row.fromTuple),
- sortAnswers = true)
- }
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
}
}
+ }
- test(s"$testName using SortMergeOuterJoin") {
- extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
- checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- EnsureRequirements(sqlContext).apply(
- SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
- expectedAnswer.map(Row.fromTuple),
- sortAnswers = false)
- }
+ test(s"$testName using SortMergeOuterJoin") {
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ EnsureRequirements(sqlContext).apply(
+ SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
}
}
}
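
Two details worth noting in this suite: BroadcastHashOuterJoin stays excluded for FullOuter because a one-sided broadcast hash join cannot null-pad unmatched rows on the broadcast side (each task only observes misses on its streamed side), and the SortMergeOuterJoin tests now pass sortAnswers = true since a full outer merge guarantees no output ordering (cf. outputOrdering returning Nil above). A toy model of the broadcast asymmetry, in plain Scala rather than Spark's operator:

    // Stream the left side against a broadcast map of the right. Left outer
    // padding falls out naturally, but the right row (3, "y") never matches
    // and is unrecoverable per-partition, which is exactly the extra
    // information a full outer join would need.
    val broadcastRight = Map(2 -> "x", 3 -> "y")
    val streamedLeft = Seq(1 -> "a", 2 -> "b")
    val leftOuter = streamedLeft.map { case (k, v) => (k, v, broadcastRight.get(k)) }
    // leftOuter == List((1,"a",None), (2,"b",Some("x")))
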