diff options
author | Davies Liu <davies@databricks.com> | 2016-04-26 12:43:47 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-04-26 12:43:47 -0700 |
commit | 7131b03bcf00cdda99e350f697946d4020a0822f (patch) | |
tree | 35960e7fd415fd776f6547a8a2da028dad883952 | |
parent | 89f082de0e2358ef8352deddcec5f8cc714f4721 (diff) | |
download | spark-7131b03bcf00cdda99e350f697946d4020a0822f.tar.gz spark-7131b03bcf00cdda99e350f697946d4020a0822f.tar.bz2 spark-7131b03bcf00cdda99e350f697946d4020a0822f.zip |
[SPARK-14853] [SQL] Support LeftSemi/LeftAnti in SortMergeJoinExec
## What changes were proposed in this pull request?
This PR update SortMergeJoinExec to support LeftSemi/LeftAnti, so it could support all the join types, same as other three join implementations: BroadcastHashJoinExec, ShuffledHashJoinExec,and BroadcastNestedLoopJoinExec.
This PR also simplify the join selection in SparkStrategy.
## How was this patch tested?
Added new tests.
Author: Davies Liu <davies@databricks.com>
Closes #12668 from davies/smj_semi.
10 files changed, 194 insertions, 175 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 0afa4c7bb9..de832ec70b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -38,13 +38,9 @@ class SparkPlanner( DDLStrategy :: SpecialLimits :: Aggregation :: - ExistenceJoin :: - EquiJoinSelection :: + JoinSelection :: InMemoryScans :: - BasicOperators :: - BroadcastNestedLoop :: - CartesianProduct :: - DefaultJoin :: Nil) + BasicOperators :: Nil) /** * Used to build table scan operators where complex projection and filtering are done using diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3c10504fbd..3955c5dc92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -64,39 +64,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object ExistenceJoin extends Strategy with PredicateHelper { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ExtractEquiJoinKeys( - LeftExistence(jt), leftKeys, rightKeys, condition, left, CanBroadcast(right)) => - Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right))) - // Find left semi joins where at least some predicates can be evaluated by matching join keys - case ExtractEquiJoinKeys( - LeftExistence(jt), leftKeys, rightKeys, condition, left, right) => - Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right))) - case _ => Nil - } - } - - /** - * Matches a plan whose output should be small enough to be used in broadcast join. - */ - object CanBroadcast { - def unapply(plan: LogicalPlan): Option[LogicalPlan] = { - if (plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) { - Some(plan) - } else { - None - } - } - } - /** - * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates - * can be evaluated by matching join keys. + * Select the proper physical plan for join based on joining keys and size of logical plan. * - * Join implementations are chosen with the following precedence: + * At first, uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the + * predicates can be evaluated by matching join keys. If found, Join implementations are chosen + * with the following precedence: * * - Broadcast: if one side of the join has an estimated physical size that is smaller than the * user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold @@ -107,8 +80,20 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * - Shuffle hash join: if the average size of a single partition is small enough to build a hash * table. * - Sort merge: if the matching join keys are sortable. + * + * If there is no joining keys, Join implementations are chosen with the following precedence: + * - BroadcastNestedLoopJoin: if one side of the join could be broadcasted + * - CartesianProduct: for Inner join + * - BroadcastNestedLoopJoin */ - object EquiJoinSelection extends Strategy with PredicateHelper { + object JoinSelection extends Strategy with PredicateHelper { + + /** + * Matches a plan whose output should be small enough to be used in broadcast join. + */ + private def canBroadcast(plan: LogicalPlan): Boolean = { + plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold + } /** * Matches a plan whose single partition should be small enough to build a hash table. @@ -116,7 +101,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Note: this assume that the number of partition is fixed, requires additional work if it's * dynamic. */ - def canBuildHashMap(plan: LogicalPlan): Boolean = { + private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = { plan.statistics.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions } @@ -131,76 +116,80 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { a.statistics.sizeInBytes * 3 <= b.statistics.sizeInBytes } - /** - * Returns whether we should use shuffle hash join or not. - * - * We should only use shuffle hash join when: - * 1) any single partition of a small table could fit in memory. - * 2) the smaller table is much smaller (3X) than the other one. - */ - private def shouldShuffleHashJoin(left: LogicalPlan, right: LogicalPlan): Boolean = { - canBuildHashMap(left) && muchSmaller(left, right) || - canBuildHashMap(right) && muchSmaller(right, left) + private def canBuildRight(joinType: JoinType): Boolean = joinType match { + case Inner | LeftOuter | LeftSemi | LeftAnti => true + case _ => false } - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + private def canBuildLeft(joinType: JoinType): Boolean = joinType match { + case Inner | RightOuter => true + case _ => false + } - // --- Inner joins -------------------------------------------------------------------------- + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => - Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right))) + // --- BroadcastHashJoin -------------------------------------------------------------------- - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + if canBuildRight(joinType) && canBroadcast(right) => Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right))) - - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if !conf.preferSortMergeJoin && shouldShuffleHashJoin(left, right) || - !RowOrdering.isOrderable(leftKeys) => - val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - BuildRight - } else { - BuildLeft - } - Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, Inner, buildSide, condition, planLater(left), planLater(right))) - - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if RowOrdering.isOrderable(leftKeys) => - joins.SortMergeJoinExec( - leftKeys, rightKeys, Inner, condition, planLater(left), planLater(right)) :: Nil - - // --- Outer joins -------------------------------------------------------------------------- + leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right))) - case ExtractEquiJoinKeys( - LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + if canBuildLeft(joinType) && canBroadcast(left) => Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right))) - case ExtractEquiJoinKeys( - RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => - Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right))) + // --- ShuffledHashJoin --------------------------------------------------------------------- - case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) - if !conf.preferSortMergeJoin && canBuildHashMap(right) && muchSmaller(right, left) || + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right) + && muchSmaller(right, left) || !RowOrdering.isOrderable(leftKeys) => Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right))) - case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) - if !conf.preferSortMergeJoin && canBuildHashMap(left) && muchSmaller(left, right) || + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left) + && muchSmaller(left, right) || !RowOrdering.isOrderable(leftKeys) => Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right))) + + // --- SortMergeJoin ------------------------------------------------------------ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if RowOrdering.isOrderable(leftKeys) => joins.SortMergeJoinExec( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + // --- Without joining keys ------------------------------------------------------------ + + // Pick BroadcastNestedLoopJoin if one side could be broadcasted + case j @ logical.Join(left, right, joinType, condition) + if canBuildRight(joinType) && canBroadcast(right) => + joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), BuildRight, joinType, condition) :: Nil + case j @ logical.Join(left, right, joinType, condition) + if canBuildLeft(joinType) && canBroadcast(left) => + joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), BuildLeft, joinType, condition) :: Nil + + // Pick CartesianProduct for InnerJoin + case logical.Join(left, right, Inner, condition) => + joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil + + case logical.Join(left, right, joinType, condition) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + BuildRight + } else { + BuildLeft + } + // This join could be very slow or OOM + joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + // --- Cases where this strategy does not apply --------------------------------------------- case _ => Nil @@ -277,45 +266,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object BroadcastNestedLoop extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case j @ logical.Join(CanBroadcast(left), right, Inner | RightOuter, condition) => - execution.joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), joins.BuildLeft, j.joinType, condition) :: Nil - case j @ logical.Join(left, CanBroadcast(right), Inner | LeftOuter | LeftSemi, condition) => - execution.joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), joins.BuildRight, j.joinType, condition) :: Nil - case _ => Nil - } - } - - object CartesianProduct extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, Inner, None) => - execution.joins.CartesianProductExec(planLater(left), planLater(right)) :: Nil - case logical.Join(left, right, Inner, Some(condition)) => - execution.FilterExec(condition, - execution.joins.CartesianProductExec(planLater(left), planLater(right))) :: Nil - case _ => Nil - } - } - - object DefaultJoin extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, joinType, condition) => - val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - joins.BuildRight - } else { - joins.BuildLeft - } - // This join could be very slow or even hang forever - joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), buildSide, joinType, condition) :: Nil - case _ => Nil - } - } - protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object InMemoryScans extends Strategy { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 3ce7c0e315..67f59197ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark._ import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -79,7 +79,10 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField } -case class CartesianProductExec(left: SparkPlan, right: SparkPlan) extends BinaryExecNode { +case class CartesianProductExec( + left: SparkPlan, + right: SparkPlan, + condition: Option[Expression]) extends BinaryExecNode { override def output: Seq[Attribute] = left.output ++ right.output override private[sql] lazy val metrics = Map( @@ -94,7 +97,18 @@ case class CartesianProductExec(left: SparkPlan, right: SparkPlan) extends Binar val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) pair.mapPartitionsInternal { iter => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) - iter.map { r => + val filtered = if (condition.isDefined) { + val boundCondition: (InternalRow) => Boolean = + newPredicate(condition.get, left.output ++ right.output) + val joined = new JoinedRow + + iter.filter { r => + boundCondition(joined(r._1, r._2)) + } + } else { + iter + } + filtered.map { r => numOutputRows += 1 joiner.join(r._1, r._2) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 96b283a5e4..a4c5491aff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -53,6 +53,8 @@ case class SortMergeJoinExec( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => (left.output ++ right.output).map(_.withNullability(true)) + case LeftExistence(_) => + left.output case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") @@ -65,6 +67,7 @@ case class SortMergeJoinExec( case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case LeftExistence(_) => left.outputPartitioning case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") @@ -100,6 +103,7 @@ case class SortMergeJoinExec( (r: InternalRow) => true } } + // An ordering that can be used to compare keys from both sides. val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output) @@ -107,27 +111,17 @@ case class SortMergeJoinExec( joinType match { case Inner => new RowIterator { - // The projection used to extract keys from input rows of the left child. - private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output) - - // The projection used to extract keys from input rows of the right child. - private[this] val rightKeyGenerator = UnsafeProjection.create(rightKeys, right.output) - - // An ordering that can be used to compare keys from both sides. - private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) private[this] var currentLeftRow: InternalRow = _ private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ private[this] var currentMatchIdx: Int = -1 private[this] val smjScanner = new SortMergeJoinScanner( - leftKeyGenerator, - rightKeyGenerator, + createLeftKeyGenerator(), + createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter) ) private[this] val joinRow = new JoinedRow - private[this] val resultProjection: (InternalRow) => InternalRow = - UnsafeProjection.create(schema) if (smjScanner.findNextInnerJoinRows()) { currentRightMatches = smjScanner.getBufferedMatches @@ -159,7 +153,7 @@ case class SortMergeJoinExec( false } - override def getRow: InternalRow = resultProjection(joinRow) + override def getRow: InternalRow = resultProj(joinRow) }.toScala case LeftOuter => @@ -204,6 +198,77 @@ case class SortMergeJoinExec( resultProj, numOutputRows).toScala + case LeftSemi => + new RowIterator { + private[this] var currentLeftRow: InternalRow = _ + private[this] val smjScanner = new SortMergeJoinScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter) + ) + private[this] val joinRow = new JoinedRow + + override def advanceNext(): Boolean = { + while (smjScanner.findNextInnerJoinRows()) { + val currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + var i = 0 + while (i < currentRightMatches.length) { + joinRow(currentLeftRow, currentRightMatches(i)) + if (boundCondition(joinRow)) { + numOutputRows += 1 + return true + } + i += 1 + } + } + false + } + + override def getRow: InternalRow = currentLeftRow + }.toScala + + case LeftAnti => + new RowIterator { + private[this] var currentLeftRow: InternalRow = _ + private[this] val smjScanner = new SortMergeJoinScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter) + ) + private[this] val joinRow = new JoinedRow + + override def advanceNext(): Boolean = { + while (smjScanner.findNextOuterJoinRows()) { + currentLeftRow = smjScanner.getStreamedRow + val currentRightMatches = smjScanner.getBufferedMatches + if (currentRightMatches == null) { + return true + } + var i = 0 + var found = false + while (!found && i < currentRightMatches.length) { + joinRow(currentLeftRow, currentRightMatches(i)) + if (boundCondition(joinRow)) { + found = true + } + i += 1 + } + if (!found) { + numOutputRows += 1 + return true + } + } + false + } + + override def getRow: InternalRow = currentLeftRow + }.toScala + case x => throw new IllegalArgumentException( s"SortMergeJoin should not take $x as the JoinType") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index ef9bb7ea4f..8cbad04e23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -37,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = sqlContext.sessionState.planner.EquiJoinSelection(join) + val planned = sqlContext.sessionState.planner.JoinSelection(join) assert(planned.size === 1) } @@ -65,7 +65,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") { Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[ShuffledHashJoinExec]), + classOf[SortMergeJoinExec]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoinExec]), ("SELECT * FROM testData JOIN testData2", classOf[CartesianProductExec]), ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProductExec]), @@ -99,7 +99,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { classOf[BroadcastNestedLoopJoinExec]), ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoinExec]), - ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoinExec]), + ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[SortMergeJoinExec]), ("SELECT * FROM testData LEFT ANTI JOIN testData2", classOf[BroadcastNestedLoopJoinExec]) ).foreach(assertJoin) } @@ -144,7 +144,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = sqlContext.sessionState.planner.EquiJoinSelection(join) + val planned = sqlContext.sessionState.planner.JoinSelection(join) assert(planned.size === 1) } @@ -449,9 +449,9 @@ class JoinSuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[ShuffledHashJoinExec]), + classOf[SortMergeJoinExec]), ("SELECT * FROM testData LEFT ANTI JOIN testData2 ON key = a", - classOf[ShuffledHashJoinExec]) + classOf[SortMergeJoinExec]) ).foreach(assertJoin) } @@ -475,7 +475,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[ShuffledHashJoinExec]), + classOf[SortMergeJoinExec]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoinExec]), ("SELECT * FROM testData JOIN testData2", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index bc838ee4da..c7c10abe9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -104,6 +104,18 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } } + test(s"$testName using SortMergeJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer, + sortAnswers = true) + } + } + } + test(s"$testName using BroadcastNestedLoopJoin build left") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 933f32e496..2a4a3690f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -189,7 +189,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { test(s"$testName using CartesianProduct") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - FilterExec(condition(), CartesianProductExec(left, right)), + CartesianProductExec(left, right, Some(condition())), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 695b1824e8..1859c6e7ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -255,28 +255,14 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") // Assume the execution plan is - // ... -> BroadcastLeftSemiJoinHash(nodeId = 0) + // ... -> BroadcastHashJoin(nodeId = 0) val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") testSparkPlanMetrics(df, 2, Map( - 0L -> ("BroadcastLeftSemiJoinHash", Map( + 0L -> ("BroadcastHashJoin", Map( "number of output rows" -> 2L))) ) } - test("ShuffledHashJoin metrics") { - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { - val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") - val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") - // Assume the execution plan is - // ... -> ShuffledHashJoin(nodeId = 0) - val df = df1.join(df2, $"key" === $"key2", "leftsemi") - testSparkPlanMetrics(df, 1, Map( - 0L -> ("ShuffledHashJoin", Map( - "number of output rows" -> 2L))) - ) - } - } - test("CartesianProduct metrics") { val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.registerTempTable("testDataForJoin") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 4a8978e553..9633f9e15b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -120,12 +120,8 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) DataSinks, Scripts, Aggregation, - ExistenceJoin, - EquiJoinSelection, - BasicOperators, - BroadcastNestedLoop, - CartesianProduct, - DefaultJoin + JoinSelection, + BasicOperators ) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 93a6f0bb58..f6b5101498 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -228,10 +228,10 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") val shj = df.queryExecution.sparkPlan.collect { - case j: ShuffledHashJoinExec => j + case j: SortMergeJoinExec => j } assert(shj.size === 1, - "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off") + "SortMergeJoinExec should be planned when BroadcastHashJoin is turned off") sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp") } |