path: root/sql
author     Josh Rosen <joshrosen@databricks.com>  2015-08-10 22:04:41 -0700
committer  Reynold Xin <rxin@databricks.com>  2015-08-10 22:04:41 -0700
commit     91e9389f39509e63654bd4bcb7bd919eaedda910 (patch)
tree       eb3dd47a78ba5ca4314fe3c574c2da69f9fc3395 /sql
parent     071bbad5db1096a548c886762b611a8484a52753 (diff)
[SPARK-9729] [SPARK-9363] [SQL] Use sort merge join for left and right outer join
This patch adds a new `SortMergeOuterJoin` operator that performs left and right outer joins using sort merge join. It also refactors `SortMergeJoin` in order to improve performance and code clarity. Along the way, I also performed a couple of pieces of minor cleanup and optimization:

- Rename the `HashJoin` physical planner rule to `EquiJoinSelection`, since it's also used for non-hash joins.
- Rewrite the comment at the top of `HashJoin` to better explain the precedence for choosing join operators.
- Update `JoinSuite` to use `SQLTestUtils.withSQLConf` for changing SQLConf settings.

This patch incorporates several ideas from adrian-wang's patch, #5717. Closes #5717.

Author: Josh Rosen <joshrosen@databricks.com>
Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #7904 from JoshRosen/outer-join-smj and squashes 1 commit.
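For context, a minimal sketch (not part of the patch) of how the change surfaces to users, assuming a `SQLContext` with the `testData`/`testData2` tables registered as in the updated `JoinSuite`:

```scala
import org.apache.spark.sql.SQLContext

// With sort merge join enabled (the default), left/right outer equi-joins should now be
// planned as SortMergeOuterJoin instead of ShuffledHashOuterJoin.
def showOuterJoinPlan(sqlContext: SQLContext): Unit = {
  val df = sqlContext.sql("SELECT * FROM testData LEFT JOIN testData2 ON key = a")
  df.explain() // inspect the physical plan for the chosen join operator
}
```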
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala  93
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala  45
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala  331
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala  251
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala  132
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala  180
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala  310
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala  125
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala  2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala  2
13 files changed, 1165 insertions(+), 319 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
index b76757c935..d3560df079 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
@@ -37,20 +37,20 @@ class JoinedRow extends InternalRow {
}
/** Updates this JoinedRow to point at two new base rows. Returns itself. */
- def apply(r1: InternalRow, r2: InternalRow): InternalRow = {
+ def apply(r1: InternalRow, r2: InternalRow): JoinedRow = {
row1 = r1
row2 = r2
this
}
/** Updates this JoinedRow by updating its left base row. Returns itself. */
- def withLeft(newLeft: InternalRow): InternalRow = {
+ def withLeft(newLeft: InternalRow): JoinedRow = {
row1 = newLeft
this
}
/** Updates this JoinedRow by updating its right base row. Returns itself. */
- def withRight(newRight: InternalRow): InternalRow = {
+ def withRight(newRight: InternalRow): JoinedRow = {
row2 = newRight
this
}
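A brief sketch (not part of the patch) of why the tightened return types help: the updater calls can now be chained without casting. The `leftRow`/`rightRow` arguments are assumed `InternalRow` values:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow

// Returns a JoinedRow directly, so callers can keep chaining withLeft/withRight.
def join(leftRow: InternalRow, rightRow: InternalRow): JoinedRow =
  new JoinedRow().withLeft(leftRow).withRight(rightRow)
```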
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index f73bb0488c..4bf00b3399 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -873,7 +873,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
HashAggregation ::
Aggregation ::
LeftSemiJoin ::
- HashJoin ::
+ EquiJoinSelection ::
InMemoryScans ::
BasicOperators ::
CartesianProduct ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala
new file mode 100644
index 0000000000..7462dbc4eb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.NoSuchElementException
+
+import org.apache.spark.sql.catalyst.InternalRow
+
+/**
+ * An internal iterator interface which presents a more restrictive API than
+ * [[scala.collection.Iterator]].
+ *
+ * One major departure from the Scala iterator API is the fusing of the `hasNext()` and `next()`
+ * calls: Scala's iterator allows users to call `hasNext()` without immediately advancing the
+ * iterator to consume the next row, whereas RowIterator combines these calls into a single
+ * [[advanceNext()]] method.
+ */
+private[sql] abstract class RowIterator {
+ /**
+ * Advance this iterator by a single row. Returns `false` if this iterator has no more rows
+ * and `true` otherwise. If this returns `true`, then the new row can be retrieved by calling
+ * [[getRow]].
+ */
+ def advanceNext(): Boolean
+
+ /**
+ * Retrieve the row from this iterator. This method is idempotent. It is illegal to call this
+ * method after [[advanceNext()]] has returned `false`.
+ */
+ def getRow: InternalRow
+
+ /**
+ * Convert this RowIterator into a [[scala.collection.Iterator]].
+ */
+ def toScala: Iterator[InternalRow] = new RowIteratorToScala(this)
+}
+
+object RowIterator {
+ def fromScala(scalaIter: Iterator[InternalRow]): RowIterator = {
+ scalaIter match {
+ case wrappedRowIter: RowIteratorToScala => wrappedRowIter.rowIter
+ case _ => new RowIteratorFromScala(scalaIter)
+ }
+ }
+}
+
+private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] {
+ private [this] var hasNextWasCalled: Boolean = false
+ private [this] var _hasNext: Boolean = false
+ override def hasNext: Boolean = {
+ // Idempotency:
+ if (!hasNextWasCalled) {
+ _hasNext = rowIter.advanceNext()
+ hasNextWasCalled = true
+ }
+ _hasNext
+ }
+ override def next(): InternalRow = {
+ if (!hasNext) throw new NoSuchElementException
+ hasNextWasCalled = false
+ rowIter.getRow
+ }
+}
+
+private final class RowIteratorFromScala(scalaIter: Iterator[InternalRow]) extends RowIterator {
+ private[this] var _next: InternalRow = null
+ override def advanceNext(): Boolean = {
+ if (scalaIter.hasNext) {
+ _next = scalaIter.next()
+ true
+ } else {
+ _next = null
+ false
+ }
+ }
+ override def getRow: InternalRow = _next
+ override def toScala: Iterator[InternalRow] = scalaIter
+}
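A standalone sketch (not Spark code) of the fused `advanceNext()`/`getRow` contract that `RowIterator` introduces, using plain `Int`s instead of `InternalRow`:

```scala
// Illustrative only: the same fused-iterator pattern on Ints.
abstract class IntIterator {
  def advanceNext(): Boolean // advance by one element; false once exhausted
  def getValue: Int          // idempotent; only valid after advanceNext() returned true
}

class IntsFromSeq(xs: Seq[Int]) extends IntIterator {
  private[this] var i = -1
  override def advanceNext(): Boolean = { i += 1; i < xs.length }
  override def getValue: Int = xs(i)
}

// Driving the iterator: each element is advanced exactly once, then read via getValue.
val it = new IntsFromSeq(Seq(1, 2, 3))
while (it.advanceNext()) println(it.getValue)
```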
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c4b9b5acea..1fc870d44b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -63,19 +63,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
/**
- * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be
- * evaluated by matching hash keys.
+ * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates
+ * can be evaluated by matching join keys.
*
- * This strategy applies a simple optimization based on the estimates of the physical sizes of
- * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an
- * estimated physical size smaller than the user-settable threshold
- * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the
- * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be
- * ''broadcasted'' to all of the executors involved in the join, as a
- * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they
- * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]].
+ * Join implementations are chosen with the following precedence:
+ *
+ * - Broadcast: if one side of the join has an estimated physical size that is smaller than the
+ * user-configurable [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold
+ * or if that side has an explicit broadcast hint (e.g. the user applied the
+ * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side
+ * of the join will be broadcasted and the other side will be streamed, with no shuffling
+ * performed. If both sides of the join are eligible to be broadcasted then the right side is broadcast.
+ * - Sort merge: if the matching join keys are sortable and
+ * [[org.apache.spark.sql.SQLConf.SORTMERGE_JOIN]] is enabled (default), then sort merge join
+ * will be used.
+ * - Hash: will be chosen if neither of the above optimizations apply to this join.
*/
- object HashJoin extends Strategy with PredicateHelper {
+ object EquiJoinSelection extends Strategy with PredicateHelper {
private[this] def makeBroadcastHashJoin(
leftKeys: Seq[Expression],
@@ -90,14 +94,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+
+ // --- Inner joins --------------------------------------------------------------------------
+
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
- // If the sort merge join option is set, we want to use sort merge join prior to hashjoin
- // for now let's support inner join first, then add outer join
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
val mergeJoin =
@@ -115,6 +120,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
+ // --- Outer joins --------------------------------------------------------------------------
+
case ExtractEquiJoinKeys(
LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
joins.BroadcastHashOuterJoin(
@@ -125,10 +132,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
joins.BroadcastHashOuterJoin(
leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
+ case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
+ if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
+ joins.SortMergeOuterJoin(
+ leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
+
+ case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
+ if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
+ joins.SortMergeOuterJoin(
+ leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
+
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
joins.ShuffledHashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
+ // --- Cases where this strategy does not apply ---------------------------------------------
+
case _ => Nil
}
}
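A hedged usage sketch of the broadcast-hint precedence described in the new comment, assuming two DataFrames `large` and `small` that share a `key` column:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.broadcast

def hintedOuterJoin(large: DataFrame, small: DataFrame): DataFrame = {
  // An explicit broadcast hint takes precedence over sort merge / shuffled hash selection,
  // so even this outer join is planned as a broadcast hash outer join.
  large.join(broadcast(small), large("key") === small("key"), "left_outer")
}
```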
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 23aebf4b06..017a44b9ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -65,8 +65,9 @@ case class BroadcastNestedLoopJoin(
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case _ =>
- left.output ++ right.output
+ case x =>
+ throw new IllegalArgumentException(
+ s"BroadcastNestedLoopJoin should not take $x as the JoinType")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 4ae23c186c..6d656ea284 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -17,15 +17,14 @@
package org.apache.spark.sql.execution.joins
-import java.util.NoSuchElementException
+import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
/**
* :: DeveloperApi ::
@@ -38,8 +37,6 @@ case class SortMergeJoin(
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
- override protected[sql] val trackNumOfRowsEnabled = true
-
override def output: Seq[Attribute] = left.output ++ right.output
override def outputPartitioning: Partitioning =
@@ -56,117 +53,265 @@ case class SortMergeJoin(
@transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
@transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
+ protected[this] def isUnsafeMode: Boolean = {
+ (codegenEnabled && unsafeEnabled
+ && UnsafeProjection.canSupport(leftKeys)
+ && UnsafeProjection.canSupport(rightKeys)
+ && UnsafeProjection.canSupport(schema))
+ }
+
+ override def outputsUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessSafeRows: Boolean = !isUnsafeMode
+
private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
// This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
keys.map(SortOrder(_, Ascending))
}
protected override def doExecute(): RDD[InternalRow] = {
- val leftResults = left.execute().map(_.copy())
- val rightResults = right.execute().map(_.copy())
-
- leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
- new Iterator[InternalRow] {
+ left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+ new RowIterator {
// An ordering that can be used to compare keys from both sides.
private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
- // Mutable per row objects.
+ private[this] var currentLeftRow: InternalRow = _
+ private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
+ private[this] var currentMatchIdx: Int = -1
+ private[this] val smjScanner = new SortMergeJoinScanner(
+ leftKeyGenerator,
+ rightKeyGenerator,
+ keyOrdering,
+ RowIterator.fromScala(leftIter),
+ RowIterator.fromScala(rightIter)
+ )
private[this] val joinRow = new JoinedRow
- private[this] var leftElement: InternalRow = _
- private[this] var rightElement: InternalRow = _
- private[this] var leftKey: InternalRow = _
- private[this] var rightKey: InternalRow = _
- private[this] var rightMatches: CompactBuffer[InternalRow] = _
- private[this] var rightPosition: Int = -1
- private[this] var stop: Boolean = false
- private[this] var matchKey: InternalRow = _
-
- // initialize iterator
- initialize()
-
- override final def hasNext: Boolean = nextMatchingPair()
-
- override final def next(): InternalRow = {
- if (hasNext) {
- // we are using the buffered right rows and run down left iterator
- val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
- rightPosition += 1
- if (rightPosition >= rightMatches.size) {
- rightPosition = 0
- fetchLeft()
- if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) {
- stop = false
- rightMatches = null
- }
- }
- joinedRow
+ private[this] val resultProjection: (InternalRow) => InternalRow = {
+ if (isUnsafeMode) {
+ UnsafeProjection.create(schema)
} else {
- // no more result
- throw new NoSuchElementException
+ identity[InternalRow]
}
}
- private def fetchLeft() = {
- if (leftIter.hasNext) {
- leftElement = leftIter.next()
- leftKey = leftKeyGenerator(leftElement)
- } else {
- leftElement = null
+ override def advanceNext(): Boolean = {
+ if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) {
+ if (smjScanner.findNextInnerJoinRows()) {
+ currentRightMatches = smjScanner.getBufferedMatches
+ currentLeftRow = smjScanner.getStreamedRow
+ currentMatchIdx = 0
+ } else {
+ currentRightMatches = null
+ currentLeftRow = null
+ currentMatchIdx = -1
+ }
}
- }
-
- private def fetchRight() = {
- if (rightIter.hasNext) {
- rightElement = rightIter.next()
- rightKey = rightKeyGenerator(rightElement)
+ if (currentLeftRow != null) {
+ joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
+ currentMatchIdx += 1
+ true
} else {
- rightElement = null
+ false
}
}
- private def initialize() = {
- fetchLeft()
- fetchRight()
+ override def getRow: InternalRow = resultProjection(joinRow)
+ }.toScala
+ }
+ }
+}
+
+/**
+ * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]].
+ *
+ * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]]
+ * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false`
+ * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return
+ * the matching row from the streamed input and may call [[getBufferedMatches]] to return the
+ * sequence of matching rows from the buffered input (in the case of an outer join, this will return
+ * an empty sequence if there are no matches from the buffered input). For efficiency, both of these
+ * methods return mutable objects which are re-used across calls to the `findNext*JoinRows()`
+ * methods.
+ *
+ * @param streamedKeyGenerator a projection that produces join keys from the streamed input.
+ * @param bufferedKeyGenerator a projection that produces join keys from the buffered input.
+ * @param keyOrdering an ordering which can be used to compare join keys.
+ * @param streamedIter an input whose rows will be streamed.
+ * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that
+ * have the same join key.
+ */
+private[joins] class SortMergeJoinScanner(
+ streamedKeyGenerator: Projection,
+ bufferedKeyGenerator: Projection,
+ keyOrdering: Ordering[InternalRow],
+ streamedIter: RowIterator,
+ bufferedIter: RowIterator) {
+ private[this] var streamedRow: InternalRow = _
+ private[this] var streamedRowKey: InternalRow = _
+ private[this] var bufferedRow: InternalRow = _
+ // Note: this is guaranteed to never have any null columns:
+ private[this] var bufferedRowKey: InternalRow = _
+ /**
+ * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty
+ */
+ private[this] var matchJoinKey: InternalRow = _
+ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
+ private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+
+ // Initialization (note: do _not_ want to advance streamed here).
+ advancedBufferedToRowWithNullFreeJoinKey()
+
+ // --- Public methods ---------------------------------------------------------------------------
+
+ def getStreamedRow: InternalRow = streamedRow
+
+ def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches
+
+ /**
+ * Advances both input iterators, stopping when we have found rows with matching join keys.
+ * @return true if matching rows have been found and false otherwise. If this returns true, then
+ * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join
+ * results.
+ */
+ final def findNextInnerJoinRows(): Boolean = {
+ while (advancedStreamed() && streamedRowKey.anyNull) {
+ // Advance the streamed side of the join until we find the next row whose join key contains
+ // no nulls or we hit the end of the streamed iterator.
+ }
+ if (streamedRow == null) {
+ // We have consumed the entire streamed iterator, so there can be no more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
+ // The new streamed row has the same join key as the previous row, so return the same matches.
+ true
+ } else if (bufferedRow == null) {
+ // The streamed row's join key does not match the current batch of buffered rows and there are
+ // no more rows to read from the buffered iterator, so there can be no more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else {
+ // Advance both the streamed and buffered iterators to find the next pair of matching rows.
+ var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+ do {
+ if (streamedRowKey.anyNull) {
+ advancedStreamed()
+ } else {
+ assert(!bufferedRowKey.anyNull)
+ comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+ if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
+ else if (comp < 0) advancedStreamed()
}
+ } while (streamedRow != null && bufferedRow != null && comp != 0)
+ if (streamedRow == null || bufferedRow == null) {
+ // We have hit the end of one of the iterators, so there can be no more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else {
+ // The streamed row's join key matches the current buffered row's join key, so walk through the
+ // buffered iterator to buffer the rest of the matching rows.
+ assert(comp == 0)
+ bufferMatchingRows()
+ true
+ }
+ }
+ }
- /**
- * Searches the right iterator for the next rows that have matches in left side, and store
- * them in a buffer.
- *
- * @return true if the search is successful, and false if the right iterator runs out of
- * tuples.
- */
- private def nextMatchingPair(): Boolean = {
- if (!stop && rightElement != null) {
- // run both side to get the first match pair
- while (!stop && leftElement != null && rightElement != null) {
- val comparing = keyOrdering.compare(leftKey, rightKey)
- // for inner join, we need to filter those null keys
- stop = comparing == 0 && !leftKey.anyNull
- if (comparing > 0 || rightKey.anyNull) {
- fetchRight()
- } else if (comparing < 0 || leftKey.anyNull) {
- fetchLeft()
- }
- }
- rightMatches = new CompactBuffer[InternalRow]()
- if (stop) {
- stop = false
- // iterate the right side to buffer all rows that matches
- // as the records should be ordered, exit when we meet the first that not match
- while (!stop && rightElement != null) {
- rightMatches += rightElement
- fetchRight()
- stop = keyOrdering.compare(leftKey, rightKey) != 0
- }
- if (rightMatches.size > 0) {
- rightPosition = 0
- matchKey = leftKey
- }
- }
+ /**
+ * Advances the streamed input iterator and buffers all rows from the buffered input that
+ * have matching keys.
+ * @return true if the streamed iterator returned a row, false otherwise. If this returns true,
+ * then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer
+ * join results.
+ */
+ final def findNextOuterJoinRows(): Boolean = {
+ if (!advancedStreamed()) {
+ // We have consumed the entire streamed iterator, so there can be no more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else {
+ if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
+ // Matches the current group, so do nothing.
+ } else {
+ // The streamed row does not match the current group.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ if (bufferedRow != null && !streamedRowKey.anyNull) {
+ // The buffered iterator could still contain matching rows, so we'll need to walk through
+ // it until we either find matches or pass where they would be found.
+ var comp = 1
+ do {
+ comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+ } while (comp > 0 && advancedBufferedToRowWithNullFreeJoinKey())
+ if (comp == 0) {
+ // We have found matches, so buffer them (this updates matchJoinKey)
+ bufferMatchingRows()
+ } else {
+ // We have overshot the position where the row would be found, hence no matches.
}
- rightMatches != null && rightMatches.size > 0
}
}
+ // If there is a streamed input then we always return true
+ true
}
}
+
+ // --- Private methods --------------------------------------------------------------------------
+
+ /**
+ * Advance the streamed iterator and compute the new row's join key.
+ * @return true if the streamed iterator returned a row and false otherwise.
+ */
+ private def advancedStreamed(): Boolean = {
+ if (streamedIter.advanceNext()) {
+ streamedRow = streamedIter.getRow
+ streamedRowKey = streamedKeyGenerator(streamedRow)
+ true
+ } else {
+ streamedRow = null
+ streamedRowKey = null
+ false
+ }
+ }
+
+ /**
+ * Advance the buffered iterator until we find a row with join key that does not contain nulls.
+ * @return true if the buffered iterator returned a row and false otherwise.
+ */
+ private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = {
+ var foundRow: Boolean = false
+ while (!foundRow && bufferedIter.advanceNext()) {
+ bufferedRow = bufferedIter.getRow
+ bufferedRowKey = bufferedKeyGenerator(bufferedRow)
+ foundRow = !bufferedRowKey.anyNull
+ }
+ if (!foundRow) {
+ bufferedRow = null
+ bufferedRowKey = null
+ false
+ } else {
+ true
+ }
+ }
+
+ /**
+ * Called when the streamed and buffered join keys match in order to buffer the matching rows.
+ */
+ private def bufferMatchingRows(): Unit = {
+ assert(streamedRowKey != null)
+ assert(!streamedRowKey.anyNull)
+ assert(bufferedRowKey != null)
+ assert(!bufferedRowKey.anyNull)
+ assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
+ // This join key may have been produced by a mutable projection, so we need to make a copy:
+ matchJoinKey = streamedRowKey.copy()
+ bufferedMatches.clear()
+ do {
+ bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
+ advancedBufferedToRowWithNullFreeJoinKey()
+ } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
+ }
}
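A standalone sketch (not Spark code) of the streamed/buffered merge scan that `SortMergeJoinScanner.findNextInnerJoinRows()` performs, written over pre-sorted `(key, value)` pairs:

```scala
// Both inputs must already be sorted by key, mirroring requiredChildOrdering.
def innerMergeJoin(
    streamed: Seq[(Int, String)],
    buffered: Seq[(Int, String)]): Seq[(String, String)] = {
  val out = scala.collection.mutable.ArrayBuffer.empty[(String, String)]
  var i = 0
  var j = 0
  while (i < streamed.length && j < buffered.length) {
    val cmp = streamed(i)._1 compare buffered(j)._1
    if (cmp < 0) i += 1        // advance the streamed side
    else if (cmp > 0) j += 1   // advance the buffered side
    else {
      // Matching key: collect the run of buffered rows with this key (the analogue of
      // bufferMatchingRows), then pair it with every streamed row sharing the key.
      val key = streamed(i)._1
      val matches = buffered.drop(j).takeWhile(_._1 == key)
      while (i < streamed.length && streamed(i)._1 == key) {
        matches.foreach(m => out += ((streamed(i)._2, m._2)))
        i += 1
      }
      j += matches.length
    }
  }
  out.toSeq
}
```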
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
new file mode 100644
index 0000000000..5326966b07
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Performs a sort merge outer join of two child relations.
+ *
+ * Note: this does not support full outer join yet; see SPARK-9730 for progress on this.
+ */
+@DeveloperApi
+case class SortMergeOuterJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode {
+
+ override def output: Seq[Attribute] = {
+ joinType match {
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case x =>
+ throw new IllegalArgumentException(
+ s"${getClass.getSimpleName} should not take $x as the JoinType")
+ }
+ }
+
+ override def outputPartitioning: Partitioning = joinType match {
+ // For left and right outer joins, the output is partitioned by the streamed input's join keys.
+ case LeftOuter => left.outputPartitioning
+ case RightOuter => right.outputPartitioning
+ case x =>
+ throw new IllegalArgumentException(
+ s"${getClass.getSimpleName} should not take $x as the JoinType")
+ }
+
+ override def outputOrdering: Seq[SortOrder] = joinType match {
+ // For left and right outer joins, the output is ordered by the streamed input's join keys.
+ case LeftOuter => requiredOrders(leftKeys)
+ case RightOuter => requiredOrders(rightKeys)
+ case x => throw new IllegalArgumentException(
+ s"SortMergeOuterJoin should not take $x as the JoinType")
+ }
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
+
+ private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
+ // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
+ keys.map(SortOrder(_, Ascending))
+ }
+
+ private def isUnsafeMode: Boolean = {
+ (codegenEnabled && unsafeEnabled
+ && UnsafeProjection.canSupport(leftKeys)
+ && UnsafeProjection.canSupport(rightKeys)
+ && UnsafeProjection.canSupport(schema))
+ }
+
+ override def outputsUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessSafeRows: Boolean = !isUnsafeMode
+
+ private def createLeftKeyGenerator(): Projection = {
+ if (isUnsafeMode) {
+ UnsafeProjection.create(leftKeys, left.output)
+ } else {
+ newProjection(leftKeys, left.output)
+ }
+ }
+
+ private def createRightKeyGenerator(): Projection = {
+ if (isUnsafeMode) {
+ UnsafeProjection.create(rightKeys, right.output)
+ } else {
+ newProjection(rightKeys, right.output)
+ }
+ }
+
+ override def doExecute(): RDD[InternalRow] = {
+ left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+ // An ordering that can be used to compare keys from both sides.
+ val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
+ val boundCondition: (InternalRow) => Boolean = {
+ condition.map { cond =>
+ newPredicate(cond, left.output ++ right.output)
+ }.getOrElse {
+ (r: InternalRow) => true
+ }
+ }
+ val resultProj: InternalRow => InternalRow = {
+ if (isUnsafeMode) {
+ UnsafeProjection.create(schema)
+ } else {
+ identity[InternalRow]
+ }
+ }
+
+ joinType match {
+ case LeftOuter =>
+ val smjScanner = new SortMergeJoinScanner(
+ streamedKeyGenerator = createLeftKeyGenerator(),
+ bufferedKeyGenerator = createRightKeyGenerator(),
+ keyOrdering,
+ streamedIter = RowIterator.fromScala(leftIter),
+ bufferedIter = RowIterator.fromScala(rightIter)
+ )
+ val rightNullRow = new GenericInternalRow(right.output.length)
+ new LeftOuterIterator(smjScanner, rightNullRow, boundCondition, resultProj).toScala
+
+ case RightOuter =>
+ val smjScanner = new SortMergeJoinScanner(
+ streamedKeyGenerator = createRightKeyGenerator(),
+ bufferedKeyGenerator = createLeftKeyGenerator(),
+ keyOrdering,
+ streamedIter = RowIterator.fromScala(rightIter),
+ bufferedIter = RowIterator.fromScala(leftIter)
+ )
+ val leftNullRow = new GenericInternalRow(left.output.length)
+ new RightOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj).toScala
+
+ case x =>
+ throw new IllegalArgumentException(
+ s"SortMergeOuterJoin should not take $x as the JoinType")
+ }
+ }
+ }
+}
+
+
+private class LeftOuterIterator(
+ smjScanner: SortMergeJoinScanner,
+ rightNullRow: InternalRow,
+ boundCondition: InternalRow => Boolean,
+ resultProj: InternalRow => InternalRow
+ ) extends RowIterator {
+ private[this] val joinedRow: JoinedRow = new JoinedRow()
+ private[this] var rightIdx: Int = 0
+ assert(smjScanner.getBufferedMatches.length == 0)
+
+ private def advanceLeft(): Boolean = {
+ rightIdx = 0
+ if (smjScanner.findNextOuterJoinRows()) {
+ joinedRow.withLeft(smjScanner.getStreamedRow)
+ if (smjScanner.getBufferedMatches.isEmpty) {
+ // There are no matching right rows, so return nulls for the right row
+ joinedRow.withRight(rightNullRow)
+ } else {
+ // Find the next row from the right input that satisfies the bound condition
+ if (!advanceRightUntilBoundConditionSatisfied()) {
+ joinedRow.withRight(rightNullRow)
+ }
+ }
+ true
+ } else {
+ // Left input has been exhausted
+ false
+ }
+ }
+
+ private def advanceRightUntilBoundConditionSatisfied(): Boolean = {
+ var foundMatch: Boolean = false
+ while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) {
+ foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx)))
+ rightIdx += 1
+ }
+ foundMatch
+ }
+
+ override def advanceNext(): Boolean = {
+ advanceRightUntilBoundConditionSatisfied() || advanceLeft()
+ }
+
+ override def getRow: InternalRow = resultProj(joinedRow)
+}
+
+private class RightOuterIterator(
+ smjScanner: SortMergeJoinScanner,
+ leftNullRow: InternalRow,
+ boundCondition: InternalRow => Boolean,
+ resultProj: InternalRow => InternalRow
+ ) extends RowIterator {
+ private[this] val joinedRow: JoinedRow = new JoinedRow()
+ private[this] var leftIdx: Int = 0
+ assert(smjScanner.getBufferedMatches.length == 0)
+
+ private def advanceRight(): Boolean = {
+ leftIdx = 0
+ if (smjScanner.findNextOuterJoinRows()) {
+ joinedRow.withRight(smjScanner.getStreamedRow)
+ if (smjScanner.getBufferedMatches.isEmpty) {
+ // There are no matching left rows, so return nulls for the left row
+ joinedRow.withLeft(leftNullRow)
+ } else {
+ // Find the next row from the left input that satisfies the bound condition
+ if (!advanceLeftUntilBoundConditionSatisfied()) {
+ joinedRow.withLeft(leftNullRow)
+ }
+ }
+ true
+ } else {
+ // Right input has been exhausted
+ false
+ }
+ }
+
+ private def advanceLeftUntilBoundConditionSatisfied(): Boolean = {
+ var foundMatch: Boolean = false
+ while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) {
+ foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx)))
+ leftIdx += 1
+ }
+ foundMatch
+ }
+
+ override def advanceNext(): Boolean = {
+ advanceLeftUntilBoundConditionSatisfied() || advanceRight()
+ }
+
+ override def getRow: InternalRow = resultProj(joinedRow)
+}
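A hedged, DataFrame-level sketch of the behaviour `SortMergeOuterJoin` implements, assuming DataFrames `left` (columns `a`, `b`) and `right` (columns `c`, `d`) like those used in `OuterJoinSuite` below:

```scala
import org.apache.spark.sql.DataFrame

def leftOuterWithCondition(left: DataFrame, right: DataFrame): DataFrame = {
  // The equi part (a = c) drives the merge scan; the non-equi part (b < d) becomes the bound
  // condition checked per buffered match. Left rows with no match that also passes b < d
  // come back padded with nulls on the right, as in LeftOuterIterator's rightNullRow handling.
  left.join(right, left("a") === right("c") && left("b") < right("d"), "left_outer")
}
```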
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 5bef1d8966..ae07eaf91c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -22,13 +22,14 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.types.BinaryType
+import org.apache.spark.sql.test.SQLTestUtils
-class JoinSuite extends QueryTest with BeforeAndAfterEach {
+class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
// Ensures tables are loaded.
TestData
+ override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.logicalPlanToSparkQuery
@@ -37,7 +38,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
val x = testData2.as("x")
val y = testData2.as("y")
val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan
- val planned = ctx.planner.HashJoin(join)
+ val planned = ctx.planner.EquiJoinSelection(join)
assert(planned.size === 1)
}
@@ -55,6 +56,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
case j: BroadcastNestedLoopJoin => j
case j: BroadcastLeftSemiJoinHash => j
case j: SortMergeJoin => j
+ case j: SortMergeOuterJoin => j
}
assert(operators.size === 1)
@@ -66,7 +68,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("join operator selection") {
ctx.cacheManager.clearCache()
- val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
@@ -83,11 +84,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
- ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]),
+ ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]),
("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
- classOf[ShuffledHashOuterJoin]),
+ classOf[SortMergeOuterJoin]),
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
- classOf[ShuffledHashOuterJoin]),
+ classOf[SortMergeOuterJoin]),
("SELECT * FROM testData full outer join testData2 ON key = a",
classOf[ShuffledHashOuterJoin]),
("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
@@ -97,82 +98,75 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
classOf[BroadcastNestedLoopJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
- try {
- ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true)
+ withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") {
Seq(
- ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
- ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
- ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin])
+ ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2",
+ classOf[ShuffledHashJoin]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2",
+ classOf[ShuffledHashJoin]),
+ ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]),
+ ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
+ classOf[ShuffledHashOuterJoin]),
+ ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
+ classOf[ShuffledHashOuterJoin]),
+ ("SELECT * FROM testData full outer join testData2 ON key = a",
+ classOf[ShuffledHashOuterJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
- } finally {
- ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED)
}
}
test("SortMergeJoin shouldn't work on unsortable columns") {
- val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled
- try {
- ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true)
+ withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
Seq(
("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
- } finally {
- ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED)
}
}
test("broadcasted hash join operator selection") {
ctx.cacheManager.clearCache()
ctx.sql("CACHE TABLE testData")
-
- val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled
- Seq(
- ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
- ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]),
- ("SELECT * FROM testData join testData2 ON key = a where key = 2",
- classOf[BroadcastHashJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
- try {
- ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true)
- Seq(
- ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
- ("SELECT * FROM testData join testData2 ON key = a and key = 2",
- classOf[BroadcastHashJoin]),
- ("SELECT * FROM testData join testData2 ON key = a where key = 2",
- classOf[BroadcastHashJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
- } finally {
- ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED)
+ for (sortMergeJoinEnabled <- Seq(true, false)) {
+ withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") {
+ withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") {
+ Seq(
+ ("SELECT * FROM testData join testData2 ON key = a",
+ classOf[BroadcastHashJoin]),
+ ("SELECT * FROM testData join testData2 ON key = a and key = 2",
+ classOf[BroadcastHashJoin]),
+ ("SELECT * FROM testData join testData2 ON key = a where key = 2",
+ classOf[BroadcastHashJoin])
+ ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ }
+ }
}
-
ctx.sql("UNCACHE TABLE testData")
}
test("broadcasted hash outer join operator selection") {
ctx.cacheManager.clearCache()
ctx.sql("CACHE TABLE testData")
-
- val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled
- Seq(
- ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]),
- ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
- classOf[BroadcastHashOuterJoin]),
- ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
- classOf[BroadcastHashOuterJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
- try {
- ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true)
+ withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
Seq(
- ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]),
+ ("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
+ classOf[SortMergeOuterJoin]),
+ ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
+ classOf[BroadcastHashOuterJoin]),
+ ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
+ classOf[BroadcastHashOuterJoin])
+ ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ }
+ withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") {
+ Seq(
+ ("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
+ classOf[ShuffledHashOuterJoin]),
("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
classOf[BroadcastHashOuterJoin]),
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
classOf[BroadcastHashOuterJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
- } finally {
- ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED)
}
-
ctx.sql("UNCACHE TABLE testData")
}
@@ -180,7 +174,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
val x = testData2.as("x")
val y = testData2.as("y")
val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan
- val planned = ctx.planner.HashJoin(join)
+ val planned = ctx.planner.EquiJoinSelection(join)
assert(planned.size === 1)
}
@@ -457,25 +451,24 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("broadcasted left semi join operator selection") {
ctx.cacheManager.clearCache()
ctx.sql("CACHE TABLE testData")
- val tmp = ctx.conf.autoBroadcastJoinThreshold
- ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000")
- Seq(
- ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
- classOf[BroadcastLeftSemiJoinHash])
- ).foreach {
- case (query, joinClass) => assertJoin(query, joinClass)
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
+ Seq(
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
+ classOf[BroadcastLeftSemiJoinHash])
+ ).foreach {
+ case (query, joinClass) => assertJoin(query, joinClass)
+ }
}
- ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1")
-
- Seq(
- ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash])
- ).foreach {
- case (query, joinClass) => assertJoin(query, joinClass)
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ Seq(
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash])
+ ).foreach {
+ case (query, joinClass) => assertJoin(query, joinClass)
+ }
}
- ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp)
ctx.sql("UNCACHE TABLE testData")
}
@@ -488,6 +481,5 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(2, 2) ::
Row(3, 1) ::
Row(3, 2) :: Nil)
-
}
}
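The `withSQLConf` pattern the suite now relies on, sketched in isolation (it assumes the body runs inside `JoinSuite` with the `SQLTestUtils` mixin, so `withSQLConf`, `assertJoin`, `SQLConf`, and the join classes are all in scope; the previous conf value is restored when the block exits, even on failure):

```scala
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") {
  // With sort merge join disabled, outer equi-joins fall back to ShuffledHashOuterJoin.
  assertJoin(
    "SELECT * FROM testData LEFT JOIN testData2 ON key = a",
    classOf[ShuffledHashOuterJoin])
}
```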
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
new file mode 100644
index 0000000000..ddff7cebcc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.execution._
+
+class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
+
+ private def testInnerJoin(
+ testName: String,
+ leftRows: DataFrame,
+ rightRows: DataFrame,
+ condition: Expression,
+ expectedAnswer: Seq[Product]): Unit = {
+ val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+ ExtractEquiJoinKeys.unapply(join).foreach {
+ case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
+
+ def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
+ val broadcastHashJoin =
+ execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right)
+ boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
+ }
+
+ def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
+ val shuffledHashJoin =
+ execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right)
+ val filteredJoin =
+ boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
+ EnsureRequirements(sqlContext).apply(filteredJoin)
+ }
+
+ def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = {
+ val sortMergeJoin =
+ execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right)
+ val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
+ EnsureRequirements(sqlContext).apply(filteredJoin)
+ }
+
+ test(s"$testName using BroadcastHashJoin (build=left)") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ makeBroadcastHashJoin(left, right, joins.BuildLeft),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+
+ test(s"$testName using BroadcastHashJoin (build=right)") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ makeBroadcastHashJoin(left, right, joins.BuildRight),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+
+ test(s"$testName using ShuffledHashJoin (build=left)") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ makeShuffledHashJoin(left, right, joins.BuildLeft),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+
+ test(s"$testName using ShuffledHashJoin (build=right)") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ makeShuffledHashJoin(left, right, joins.BuildRight),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+
+ test(s"$testName using SortMergeJoin") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ makeSortMergeJoin(left, right),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+ }
+ }
+
+ {
+ val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+ Row(1, "A"),
+ Row(2, "B"),
+ Row(3, "C"),
+ Row(4, "D"),
+ Row(5, "E"),
+ Row(6, "F"),
+ Row(null, "G")
+ )), new StructType().add("N", IntegerType).add("L", StringType))
+
+ val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+ Row(1, "a"),
+ Row(2, "b"),
+ Row(3, "c"),
+ Row(4, "d"),
+ Row(null, "e")
+ )), new StructType().add("n", IntegerType).add("l", StringType))
+
+ testInnerJoin(
+ "inner join, one match per row",
+ upperCaseData,
+ lowerCaseData,
+ (upperCaseData.col("N") === lowerCaseData.col("n")).expr,
+ Seq(
+ (1, "A", 1, "a"),
+ (2, "B", 2, "b"),
+ (3, "C", 3, "c"),
+ (4, "D", 4, "d")
+ )
+ )
+ }
+
+ private val testData2 = Seq(
+ (1, 1),
+ (1, 2),
+ (2, 1),
+ (2, 2),
+ (3, 1),
+ (3, 2)
+ ).toDF("a", "b")
+
+ {
+ val left = testData2.where("a = 1")
+ val right = testData2.where("a = 1")
+ testInnerJoin(
+ "inner join, multiple matches",
+ left,
+ right,
+ (left.col("a") === right.col("a")).expr,
+ Seq(
+ (1, 1, 1, 1),
+ (1, 1, 1, 2),
+ (1, 2, 1, 1),
+ (1, 2, 1, 2)
+ )
+ )
+ }
+
+ {
+ val left = testData2.where("a = 1")
+ val right = testData2.where("a = 2")
+ testInnerJoin(
+ "inner join, no matches",
+ left,
+ right,
+ (left.col("a") === right.col("a")).expr,
+ Seq.empty
+ )
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 2c27da596b..e16f5e39aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -1,89 +1,221 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan}
-import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
-import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
-
-class OuterJoinSuite extends SparkPlanTest {
-
- val left = Seq(
- (1, 2.0),
- (2, 1.0),
- (3, 3.0)
- ).toDF("a", "b")
-
- val right = Seq(
- (2, 3.0),
- (3, 2.0),
- (4, 1.0)
- ).toDF("c", "d")
-
- val leftKeys: List[Expression] = 'a :: Nil
- val rightKeys: List[Expression] = 'c :: Nil
- val condition = Some(LessThan('b, 'd))
-
- test("shuffled hash outer join") {
- checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
- ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
- Seq(
- (1, 2.0, null, null),
- (2, 1.0, 2, 3.0),
- (3, 3.0, null, null)
- ).map(Row.fromTuple))
-
- checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
- ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
- Seq(
- (2, 1.0, 2, 3.0),
- (null, null, 3, 2.0),
- (null, null, 4, 1.0)
- ).map(Row.fromTuple))
-
- checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
- ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right),
- Seq(
- (1, 2.0, null, null),
- (2, 1.0, 2, 3.0),
- (3, 3.0, null, null),
- (null, null, 3, 2.0),
- (null, null, 4, 1.0)
- ).map(Row.fromTuple))
- }
-
- test("broadcast hash outer join") {
- checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
- BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
- Seq(
- (1, 2.0, null, null),
- (2, 1.0, 2, 3.0),
- (3, 3.0, null, null)
- ).map(Row.fromTuple))
-
- checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
- BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
- Seq(
- (2, 1.0, 2, 3.0),
- (null, null, 3, 2.0),
- (null, null, 4, 1.0)
- ).map(Row.fromTuple))
- }
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType}
+import org.apache.spark.sql.{SQLConf, DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest}
+
+class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
+
+ private def testOuterJoin(
+ testName: String,
+ leftRows: DataFrame,
+ rightRows: DataFrame,
+ joinType: JoinType,
+ condition: Expression,
+ expectedAnswer: Seq[Product]): Unit = {
+ val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+ ExtractEquiJoinKeys.unapply(join).foreach {
+ case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
+ test(s"$testName using ShuffledHashOuterJoin") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ EnsureRequirements(sqlContext).apply(
+ ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+
+ if (joinType != FullOuter) {
+ test(s"$testName using BroadcastHashOuterJoin") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+
+ test(s"$testName using SortMergeOuterJoin") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ EnsureRequirements(sqlContext).apply(
+ SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = false)
+ }
+ }
+ }
+ }
+
+ test(s"$testName using BroadcastNestedLoopJoin (build=left)") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+
+ test(s"$testName using BroadcastNestedLoopJoin (build=right)") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+ }
+
+ val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+ Row(1, 2.0),
+ Row(2, 100.0),
+ Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches
+ Row(2, 1.0),
+ Row(3, 3.0),
+ Row(5, 1.0),
+ Row(6, 6.0),
+ Row(null, null)
+ )), new StructType().add("a", IntegerType).add("b", DoubleType))
+
+ val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+ Row(0, 0.0),
+ Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
+ Row(2, -1.0),
+ Row(2, -1.0),
+ Row(2, 3.0),
+ Row(3, 2.0),
+ Row(4, 1.0),
+ Row(5, 3.0),
+ Row(7, 7.0),
+ Row(null, null)
+ )), new StructType().add("c", IntegerType).add("d", DoubleType))
+
+ val condition = {
+ And(
+ (left.col("a") === right.col("c")).expr,
+ LessThan(left.col("b").expr, right.col("d").expr))
+ }
+
+ // --- Basic outer joins ------------------------------------------------------------------------
+
+ testOuterJoin(
+ "basic left outer join",
+ left,
+ right,
+ LeftOuter,
+ condition,
+ Seq(
+ (null, null, null, null),
+ (1, 2.0, null, null),
+ (2, 100.0, null, null),
+ (2, 1.0, 2, 3.0),
+ (2, 1.0, 2, 3.0),
+ (2, 1.0, 2, 3.0),
+ (2, 1.0, 2, 3.0),
+ (3, 3.0, null, null),
+ (5, 1.0, 5, 3.0),
+ (6, 6.0, null, null)
+ )
+ )
+
+ testOuterJoin(
+ "basic right outer join",
+ left,
+ right,
+ RightOuter,
+ condition,
+ Seq(
+ (null, null, null, null),
+ (null, null, 0, 0.0),
+ (2, 1.0, 2, 3.0),
+ (2, 1.0, 2, 3.0),
+ (null, null, 2, -1.0),
+ (null, null, 2, -1.0),
+ (2, 1.0, 2, 3.0),
+ (2, 1.0, 2, 3.0),
+ (null, null, 3, 2.0),
+ (null, null, 4, 1.0),
+ (5, 1.0, 5, 3.0),
+ (null, null, 7, 7.0)
+ )
+ )
+
+ testOuterJoin(
+ "basic full outer join",
+ left,
+ right,
+ FullOuter,
+ condition,
+ Seq(
+ (1, 2.0, null, null),
+ (null, null, 2, -1.0),
+ (null, null, 2, -1.0),
+ (2, 100.0, null, null),
+ (2, 1.0, 2, 3.0),
+ (2, 1.0, 2, 3.0),
+ (2, 1.0, 2, 3.0),
+ (2, 1.0, 2, 3.0),
+ (3, 3.0, null, null),
+ (5, 1.0, 5, 3.0),
+ (6, 6.0, null, null),
+ (null, null, 0, 0.0),
+ (null, null, 3, 2.0),
+ (null, null, 4, 1.0),
+ (null, null, 7, 7.0),
+ (null, null, null, null),
+ (null, null, null, null)
+ )
+ )
+
+ // --- Both inputs empty ------------------------------------------------------------------------
+
+ testOuterJoin(
+ "left outer join with both inputs empty",
+ left.filter("false"),
+ right.filter("false"),
+ LeftOuter,
+ condition,
+ Seq.empty
+ )
+
+ testOuterJoin(
+ "right outer join with both inputs empty",
+ left.filter("false"),
+ right.filter("false"),
+ RightOuter,
+ condition,
+ Seq.empty
+ )
+
+ testOuterJoin(
+ "full outer join with both inputs empty",
+ left.filter("false"),
+ right.filter("false"),
+ FullOuter,
+ condition,
+ Seq.empty
+ )
+}
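For reference, a minimal usage sketch (not part of the patch) of the end-user query that these physical operators implement. It assumes the left/right frames and condition defined above, and that this Spark version accepts the "left_outer" join-type string:

  // Logical query behind the "basic left outer join" case: the planner picks one of
  // the operators exercised above (ShuffledHashOuterJoin, BroadcastHashOuterJoin,
  // SortMergeOuterJoin, or BroadcastNestedLoopJoin) depending on statistics and conf.
  val leftOuterJoined = left.join(
    right,
    left.col("a") === right.col("c") && left.col("b") < right.col("d"),
    "left_outer")
  // Up to row order, collect() should return the expectedAnswer passed to
  // testOuterJoin("basic left outer join", ...) above.
  leftOuterJoined.collect()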
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 927e85a7db..4503ed251f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -17,58 +17,91 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression}
-import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
+import org.apache.spark.sql.{SQLConf, DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
+import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
+class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
-class SemiJoinSuite extends SparkPlanTest{
- val left = Seq(
- (1, 2.0),
- (1, 2.0),
- (2, 1.0),
- (2, 1.0),
- (3, 3.0)
- ).toDF("a", "b")
+ private def testLeftSemiJoin(
+ testName: String,
+ leftRows: DataFrame,
+ rightRows: DataFrame,
+ condition: Expression,
+ expectedAnswer: Seq[Product]): Unit = {
+ val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+ ExtractEquiJoinKeys.unapply(join).foreach {
+ case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
+ test(s"$testName using LeftSemiJoinHash") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ EnsureRequirements(left.sqlContext).apply(
+ LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
- val right = Seq(
- (2, 3.0),
- (2, 3.0),
- (3, 2.0),
- (4, 1.0)
- ).toDF("c", "d")
+ test(s"$testName using BroadcastLeftSemiJoinHash") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+ }
- val leftKeys: List[Expression] = 'a :: Nil
- val rightKeys: List[Expression] = 'c :: Nil
- val condition = Some(LessThan('b, 'd))
-
- test("left semi join hash") {
- checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
- LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
- Seq(
- (2, 1.0),
- (2, 1.0)
- ).map(Row.fromTuple))
+ test(s"$testName using LeftSemiJoinBNL") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ LeftSemiJoinBNL(left, right, Some(condition)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
}
- test("left semi join BNL") {
- checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
- LeftSemiJoinBNL(left, right, condition),
- Seq(
- (1, 2.0),
- (1, 2.0),
- (2, 1.0),
- (2, 1.0)
- ).map(Row.fromTuple))
- }
+ val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+ Row(1, 2.0),
+ Row(1, 2.0),
+ Row(2, 1.0),
+ Row(2, 1.0),
+ Row(3, 3.0),
+ Row(null, null),
+ Row(null, 5.0),
+ Row(6, null)
+ )), new StructType().add("a", IntegerType).add("b", DoubleType))
- test("broadcast left semi join hash") {
- checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
- BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
- Seq(
- (2, 1.0),
- (2, 1.0)
- ).map(Row.fromTuple))
+ val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+ Row(2, 3.0),
+ Row(2, 3.0),
+ Row(3, 2.0),
+ Row(4, 1.0),
+ Row(null, null),
+ Row(null, 5.0),
+ Row(6, null)
+ )), new StructType().add("c", IntegerType).add("d", DoubleType))
+
+ val condition = {
+ And(
+ (left.col("a") === right.col("c")).expr,
+ LessThan(left.col("b").expr, right.col("d").expr))
}
+
+ testLeftSemiJoin(
+ "basic test",
+ left,
+ right,
+ condition,
+ Seq(
+ (2, 1.0),
+ (2, 1.0)
+ )
+ )
}
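For reference, a semantics sketch (not part of the patch) of why the expected answer is two copies of (2, 1.0): a left semi join emits each left row at most once when any right row satisfies the condition. The null-keyed rows added to the fixtures never satisfy the equi-condition, so plain tuples suffice here:

  // Left semi join over the non-null fixture rows, expressed with Scala collections.
  val leftRows  = Seq((1, 2.0), (1, 2.0), (2, 1.0), (2, 1.0), (3, 3.0))
  val rightRows = Seq((2, 3.0), (2, 3.0), (3, 2.0), (4, 1.0))
  val semiJoined = leftRows.filter { case (a, b) =>
    rightRows.exists { case (c, d) => a == c && b < d }
  }
  // semiJoined == Seq((2, 1.0), (2, 1.0)) -- only the a = 2 rows find a d with b < d.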
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 4c11acdab9..1066695589 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
trait SQLTestUtils { this: SparkFunSuite =>
- def sqlContext: SQLContext
+ protected def sqlContext: SQLContext
protected def configuration = sqlContext.sparkContext.hadoopConfiguration
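The join suites above rely on this trait's withSQLConf helper; a usage sketch (assumption: the helper takes key/value pairs plus a block and restores the previous values afterwards, as the calls above suggest):

  // Pin shuffle partitions to 1 so the tiny test inputs land in a single partition,
  // which keeps the physical plans above small and the test output deterministic.
  withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
    // any checkAnswer2(...) executed here sees the temporary setting;
    // the previous value is restored when the block exits.
  }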
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 567d7fa12f..f17177a771 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -531,7 +531,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
HashAggregation,
Aggregation,
LeftSemiJoin,
- HashJoin,
+ EquiJoinSelection,
BasicOperators,
CartesianProduct,
BroadcastNestedLoopJoin