author    Davies Liu <davies@databricks.com>  2016-04-26 12:43:47 -0700
committer Davies Liu <davies.liu@gmail.com>  2016-04-26 12:43:47 -0700
commit    7131b03bcf00cdda99e350f697946d4020a0822f (patch)
tree      35960e7fd415fd776f6547a8a2da028dad883952
parent    89f082de0e2358ef8352deddcec5f8cc714f4721 (diff)
[SPARK-14853] [SQL] Support LeftSemi/LeftAnti in SortMergeJoinExec
## What changes were proposed in this pull request?

This PR updates SortMergeJoinExec to support LeftSemi/LeftAnti, so that it supports all the join types, the same as the other three join implementations: BroadcastHashJoinExec, ShuffledHashJoinExec, and BroadcastNestedLoopJoinExec. This PR also simplifies the join selection in SparkStrategies.

## How was this patch tested?

Added new tests.

Author: Davies Liu <davies@databricks.com>

Closes #12668 from davies/smj_semi.
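For intuition about the user-visible effect, here is a minimal REPL-style sketch (not part of the patch; the data is hypothetical and a Spark 2.0-era sqlContext is assumed to be in scope, as in spark-shell). With broadcast joins disabled, LEFT SEMI and LEFT ANTI equi-joins are now planned as sort-merge joins instead of shuffled hash joins:

```scala
// Hypothetical spark-shell session; sqlContext provided by the shell.
import sqlContext.implicits._

// Disable broadcast joins so the non-broadcast path is exercised.
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "-1")

val left = Seq((1, "a"), (2, "b"), (3, "c")).toDF("key", "v")
val right = Seq((1, "x"), (3, "y")).toDF("key2", "w")

left.join(right, $"key" === $"key2", "leftsemi").explain()
// expected to show a SortMergeJoin node with join type LeftSemi

left.join(right, $"key" === $"key2", "leftanti").explain()
// likewise planned as SortMergeJoin with join type LeftAnti
```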
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 192
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 91
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala | 18
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala | 8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 4
10 files changed, 194 insertions(+), 175 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 0afa4c7bb9..de832ec70b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -38,13 +38,9 @@ class SparkPlanner(
DDLStrategy ::
SpecialLimits ::
Aggregation ::
- ExistenceJoin ::
- EquiJoinSelection ::
+ JoinSelection ::
InMemoryScans ::
- BasicOperators ::
- BroadcastNestedLoop ::
- CartesianProduct ::
- DefaultJoin :: Nil)
+ BasicOperators :: Nil)
/**
* Used to build table scan operators where complex projection and filtering are done using
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3c10504fbd..3955c5dc92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -64,39 +64,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
- object ExistenceJoin extends Strategy with PredicateHelper {
- def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case ExtractEquiJoinKeys(
- LeftExistence(jt), leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
- Seq(joins.BroadcastHashJoinExec(
- leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right)))
- // Find left semi joins where at least some predicates can be evaluated by matching join keys
- case ExtractEquiJoinKeys(
- LeftExistence(jt), leftKeys, rightKeys, condition, left, right) =>
- Seq(joins.ShuffledHashJoinExec(
- leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right)))
- case _ => Nil
- }
- }
-
- /**
- * Matches a plan whose output should be small enough to be used in broadcast join.
- */
- object CanBroadcast {
- def unapply(plan: LogicalPlan): Option[LogicalPlan] = {
- if (plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
- Some(plan)
- } else {
- None
- }
- }
- }
-
/**
- * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates
- * can be evaluated by matching join keys.
+ * Selects the proper physical plan for a join based on the join keys and the size of the logical plans.
*
- * Join implementations are chosen with the following precedence:
+ * First, this uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the
+ * predicates can be evaluated by matching join keys. If found, join implementations are chosen
+ * with the following precedence:
*
* - Broadcast: if one side of the join has an estimated physical size that is smaller than the
* user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold
@@ -107,8 +80,20 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* - Shuffle hash join: if the average size of a single partition is small enough to build a hash
* table.
* - Sort merge: if the matching join keys are sortable.
+ *
+ * If there are no join keys, join implementations are chosen with the following precedence:
+ * - BroadcastNestedLoopJoin: if one side of the join could be broadcasted
+ * - CartesianProduct: for Inner join
+ * - BroadcastNestedLoopJoin
*/
- object EquiJoinSelection extends Strategy with PredicateHelper {
+ object JoinSelection extends Strategy with PredicateHelper {
+
+ /**
+ * Matches a plan whose output should be small enough to be used in broadcast join.
+ */
+ private def canBroadcast(plan: LogicalPlan): Boolean = {
+ plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold
+ }
/**
* Matches a plan whose single partition should be small enough to build a hash table.
@@ -116,7 +101,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* Note: this assumes that the number of partitions is fixed; it requires additional work if it's
* dynamic.
*/
- def canBuildHashMap(plan: LogicalPlan): Boolean = {
+ private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = {
plan.statistics.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
}
@@ -131,76 +116,80 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
a.statistics.sizeInBytes * 3 <= b.statistics.sizeInBytes
}
- /**
- * Returns whether we should use shuffle hash join or not.
- *
- * We should only use shuffle hash join when:
- * 1) any single partition of a small table could fit in memory.
- * 2) the smaller table is much smaller (3X) than the other one.
- */
- private def shouldShuffleHashJoin(left: LogicalPlan, right: LogicalPlan): Boolean = {
- canBuildHashMap(left) && muchSmaller(left, right) ||
- canBuildHashMap(right) && muchSmaller(right, left)
+ private def canBuildRight(joinType: JoinType): Boolean = joinType match {
+ case Inner | LeftOuter | LeftSemi | LeftAnti => true
+ case _ => false
}
- def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ private def canBuildLeft(joinType: JoinType): Boolean = joinType match {
+ case Inner | RightOuter => true
+ case _ => false
+ }
- // --- Inner joins --------------------------------------------------------------------------
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
- Seq(joins.BroadcastHashJoinExec(
- leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right)))
+ // --- BroadcastHashJoin --------------------------------------------------------------------
- case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+ if canBuildRight(joinType) && canBroadcast(right) =>
Seq(joins.BroadcastHashJoinExec(
- leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))
-
- case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
- if !conf.preferSortMergeJoin && shouldShuffleHashJoin(left, right) ||
- !RowOrdering.isOrderable(leftKeys) =>
- val buildSide =
- if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
- BuildRight
- } else {
- BuildLeft
- }
- Seq(joins.ShuffledHashJoinExec(
- leftKeys, rightKeys, Inner, buildSide, condition, planLater(left), planLater(right)))
-
- case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
- if RowOrdering.isOrderable(leftKeys) =>
- joins.SortMergeJoinExec(
- leftKeys, rightKeys, Inner, condition, planLater(left), planLater(right)) :: Nil
-
- // --- Outer joins --------------------------------------------------------------------------
+ leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
- case ExtractEquiJoinKeys(
- LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+ if canBuildLeft(joinType) && canBroadcast(left) =>
Seq(joins.BroadcastHashJoinExec(
- leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
+ leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
- case ExtractEquiJoinKeys(
- RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
- Seq(joins.BroadcastHashJoinExec(
- leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
+ // --- ShuffledHashJoin ---------------------------------------------------------------------
- case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
- if !conf.preferSortMergeJoin && canBuildHashMap(right) && muchSmaller(right, left) ||
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+ if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right)
+ && muchSmaller(right, left) ||
!RowOrdering.isOrderable(leftKeys) =>
Seq(joins.ShuffledHashJoinExec(
- leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
+ leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
- case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
- if !conf.preferSortMergeJoin && canBuildHashMap(left) && muchSmaller(left, right) ||
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+ if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
+ && muchSmaller(left, right) ||
!RowOrdering.isOrderable(leftKeys) =>
Seq(joins.ShuffledHashJoinExec(
- leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
+ leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
+
+ // --- SortMergeJoin ------------------------------------------------------------
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
joins.SortMergeJoinExec(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
+ // --- Without joining keys ------------------------------------------------------------
+
+ // Pick BroadcastNestedLoopJoin if one side can be broadcast
+ case j @ logical.Join(left, right, joinType, condition)
+ if canBuildRight(joinType) && canBroadcast(right) =>
+ joins.BroadcastNestedLoopJoinExec(
+ planLater(left), planLater(right), BuildRight, joinType, condition) :: Nil
+ case j @ logical.Join(left, right, joinType, condition)
+ if canBuildLeft(joinType) && canBroadcast(left) =>
+ joins.BroadcastNestedLoopJoinExec(
+ planLater(left), planLater(right), BuildLeft, joinType, condition) :: Nil
+
+ // Pick CartesianProduct for inner joins
+ case logical.Join(left, right, Inner, condition) =>
+ joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
+
+ case logical.Join(left, right, joinType, condition) =>
+ val buildSide =
+ if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
+ BuildRight
+ } else {
+ BuildLeft
+ }
+ // This join could be very slow or OOM
+ joins.BroadcastNestedLoopJoinExec(
+ planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
+
// --- Cases where this strategy does not apply ---------------------------------------------
case _ => Nil
@@ -277,45 +266,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
- object BroadcastNestedLoop extends Strategy {
- def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case j @ logical.Join(CanBroadcast(left), right, Inner | RightOuter, condition) =>
- execution.joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), joins.BuildLeft, j.joinType, condition) :: Nil
- case j @ logical.Join(left, CanBroadcast(right), Inner | LeftOuter | LeftSemi, condition) =>
- execution.joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), joins.BuildRight, j.joinType, condition) :: Nil
- case _ => Nil
- }
- }
-
- object CartesianProduct extends Strategy {
- def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.Join(left, right, Inner, None) =>
- execution.joins.CartesianProductExec(planLater(left), planLater(right)) :: Nil
- case logical.Join(left, right, Inner, Some(condition)) =>
- execution.FilterExec(condition,
- execution.joins.CartesianProductExec(planLater(left), planLater(right))) :: Nil
- case _ => Nil
- }
- }
-
- object DefaultJoin extends Strategy {
- def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.Join(left, right, joinType, condition) =>
- val buildSide =
- if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
- joins.BuildRight
- } else {
- joins.BuildLeft
- }
- // This join could be very slow or even hang forever
- joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
- case _ => Nil
- }
- }
-
protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
object InMemoryScans extends Strategy {
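Taken together, the consolidated strategy replaces four special-purpose strategies with one. Below is a compact, simplified sketch of the precedence it encodes for equi-joins (plain Scala, not the Spark API; Side, the threshold values, and the returned operator names are hypothetical stand-ins for plan statistics, SQLConf values, and physical plans):

```scala
// Simplified model of JoinSelection's equi-join precedence.
object JoinSelectionSketch {
  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftOuter extends JoinType
  case object RightOuter extends JoinType
  case object LeftSemi extends JoinType
  case object LeftAnti extends JoinType

  // Hypothetical stand-in for a logical plan's statistics.
  case class Side(sizeInBytes: BigInt, keysSortable: Boolean)

  // Hypothetical conf values (autoBroadcastJoinThreshold, numShufflePartitions).
  val broadcastThreshold: BigInt = BigInt(10) * 1024 * 1024
  val numShufflePartitions: Int = 200

  def canBroadcast(p: Side): Boolean = p.sizeInBytes <= broadcastThreshold
  def canBuildLocalHashMap(p: Side): Boolean =
    p.sizeInBytes < broadcastThreshold * numShufflePartitions
  def muchSmaller(a: Side, b: Side): Boolean = a.sizeInBytes * 3 <= b.sizeInBytes

  def canBuildRight(jt: JoinType): Boolean = jt match {
    case Inner | LeftOuter | LeftSemi | LeftAnti => true
    case _ => false
  }
  def canBuildLeft(jt: JoinType): Boolean = jt match {
    case Inner | RightOuter => true
    case _ => false
  }

  // Mirrors the order of the cases in JoinSelection.apply.
  def select(jt: JoinType, left: Side, right: Side, preferSortMergeJoin: Boolean): String =
    if (canBuildRight(jt) && canBroadcast(right)) "BroadcastHashJoinExec(BuildRight)"
    else if (canBuildLeft(jt) && canBroadcast(left)) "BroadcastHashJoinExec(BuildLeft)"
    else if (!preferSortMergeJoin && canBuildRight(jt) &&
             canBuildLocalHashMap(right) && muchSmaller(right, left))
      "ShuffledHashJoinExec(BuildRight)"
    else if (!preferSortMergeJoin && canBuildLeft(jt) &&
             canBuildLocalHashMap(left) && muchSmaller(left, right))
      "ShuffledHashJoinExec(BuildLeft)"
    else if (left.keysSortable) "SortMergeJoinExec"
    else "ShuffledHashJoinExec" // non-orderable keys fall back to hashing
}
```

The sketch folds the non-orderable-keys escape hatch into the final branch; in the real strategy it appears as the `|| !RowOrdering.isOrderable(leftKeys)` disjunct in the ShuffledHashJoinExec guards.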
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 3ce7c0e315..67f59197ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark._
import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -79,7 +79,10 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
}
-case class CartesianProductExec(left: SparkPlan, right: SparkPlan) extends BinaryExecNode {
+case class CartesianProductExec(
+ left: SparkPlan,
+ right: SparkPlan,
+ condition: Option[Expression]) extends BinaryExecNode {
override def output: Seq[Attribute] = left.output ++ right.output
override private[sql] lazy val metrics = Map(
@@ -94,7 +97,18 @@ case class CartesianProductExec(left: SparkPlan, right: SparkPlan) extends Binar
val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
pair.mapPartitionsInternal { iter =>
val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
- iter.map { r =>
+ val filtered = if (condition.isDefined) {
+ val boundCondition: (InternalRow) => Boolean =
+ newPredicate(condition.get, left.output ++ right.output)
+ val joined = new JoinedRow
+
+ iter.filter { r =>
+ boundCondition(joined(r._1, r._2))
+ }
+ } else {
+ iter
+ }
+ filtered.map { r =>
numOutputRows += 1
joiner.join(r._1, r._2)
}
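The behavioral idea of this change, sketched in plain Scala (iterators and a plain case class standing in for RDDs and JoinedRow; all names hypothetical): the optional condition is now evaluated while streaming pairs out of the product, instead of in a separate FilterExec node above it.

```scala
// Toy model of CartesianProductExec with an optional join condition.
case class Pair(a: Int, b: Int)

def cartesianWithCondition(
    left: Seq[Int],
    right: Seq[Int],
    condition: Option[Pair => Boolean]): Iterator[Pair] = {
  // Stream all pairs; right.iterator is re-created for each left element.
  val pairs = for (l <- left.iterator; r <- right.iterator) yield Pair(l, r)
  condition match {
    case Some(pred) => pairs.filter(pred) // mirrors iter.filter over JoinedRow
    case None       => pairs
  }
}

// Example: the condition a < b is applied during the product.
val out = cartesianWithCondition(Seq(1, 2), Seq(1, 3), Some(p => p.a < p.b)).toList
// out == List(Pair(1, 3), Pair(2, 3))
```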
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 96b283a5e4..a4c5491aff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -53,6 +53,8 @@ case class SortMergeJoinExec(
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
(left.output ++ right.output).map(_.withNullability(true))
+ case LeftExistence(_) =>
+ left.output
case x =>
throw new IllegalArgumentException(
s"${getClass.getSimpleName} should not take $x as the JoinType")
@@ -65,6 +67,7 @@ case class SortMergeJoinExec(
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+ case LeftExistence(_) => left.outputPartitioning
case x =>
throw new IllegalArgumentException(
s"${getClass.getSimpleName} should not take $x as the JoinType")
@@ -100,6 +103,7 @@ case class SortMergeJoinExec(
(r: InternalRow) => true
}
}
+
// An ordering that can be used to compare keys from both sides.
val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
@@ -107,27 +111,17 @@ case class SortMergeJoinExec(
joinType match {
case Inner =>
new RowIterator {
- // The projection used to extract keys from input rows of the left child.
- private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
-
- // The projection used to extract keys from input rows of the right child.
- private[this] val rightKeyGenerator = UnsafeProjection.create(rightKeys, right.output)
-
- // An ordering that can be used to compare keys from both sides.
- private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
private[this] var currentLeftRow: InternalRow = _
private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
private[this] var currentMatchIdx: Int = -1
private[this] val smjScanner = new SortMergeJoinScanner(
- leftKeyGenerator,
- rightKeyGenerator,
+ createLeftKeyGenerator(),
+ createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter)
)
private[this] val joinRow = new JoinedRow
- private[this] val resultProjection: (InternalRow) => InternalRow =
- UnsafeProjection.create(schema)
if (smjScanner.findNextInnerJoinRows()) {
currentRightMatches = smjScanner.getBufferedMatches
@@ -159,7 +153,7 @@ case class SortMergeJoinExec(
false
}
- override def getRow: InternalRow = resultProjection(joinRow)
+ override def getRow: InternalRow = resultProj(joinRow)
}.toScala
case LeftOuter =>
@@ -204,6 +198,77 @@ case class SortMergeJoinExec(
resultProj,
numOutputRows).toScala
+ case LeftSemi =>
+ new RowIterator {
+ private[this] var currentLeftRow: InternalRow = _
+ private[this] val smjScanner = new SortMergeJoinScanner(
+ createLeftKeyGenerator(),
+ createRightKeyGenerator(),
+ keyOrdering,
+ RowIterator.fromScala(leftIter),
+ RowIterator.fromScala(rightIter)
+ )
+ private[this] val joinRow = new JoinedRow
+
+ override def advanceNext(): Boolean = {
+ while (smjScanner.findNextInnerJoinRows()) {
+ val currentRightMatches = smjScanner.getBufferedMatches
+ currentLeftRow = smjScanner.getStreamedRow
+ var i = 0
+ while (i < currentRightMatches.length) {
+ joinRow(currentLeftRow, currentRightMatches(i))
+ if (boundCondition(joinRow)) {
+ numOutputRows += 1
+ return true
+ }
+ i += 1
+ }
+ }
+ false
+ }
+
+ override def getRow: InternalRow = currentLeftRow
+ }.toScala
+
+ case LeftAnti =>
+ new RowIterator {
+ private[this] var currentLeftRow: InternalRow = _
+ private[this] val smjScanner = new SortMergeJoinScanner(
+ createLeftKeyGenerator(),
+ createRightKeyGenerator(),
+ keyOrdering,
+ RowIterator.fromScala(leftIter),
+ RowIterator.fromScala(rightIter)
+ )
+ private[this] val joinRow = new JoinedRow
+
+ override def advanceNext(): Boolean = {
+ while (smjScanner.findNextOuterJoinRows()) {
+ currentLeftRow = smjScanner.getStreamedRow
+ val currentRightMatches = smjScanner.getBufferedMatches
+ if (currentRightMatches == null) {
+ return true
+ }
+ var i = 0
+ var found = false
+ while (!found && i < currentRightMatches.length) {
+ joinRow(currentLeftRow, currentRightMatches(i))
+ if (boundCondition(joinRow)) {
+ found = true
+ }
+ i += 1
+ }
+ if (!found) {
+ numOutputRows += 1
+ return true
+ }
+ }
+ false
+ }
+
+ override def getRow: InternalRow = currentLeftRow
+ }.toScala
+
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin should not take $x as the JoinType")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index ef9bb7ea4f..8cbad04e23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -37,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
val x = testData2.as("x")
val y = testData2.as("y")
val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan
- val planned = sqlContext.sessionState.planner.EquiJoinSelection(join)
+ val planned = sqlContext.sessionState.planner.JoinSelection(join)
assert(planned.size === 1)
}
@@ -65,7 +65,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
- classOf[ShuffledHashJoinExec]),
+ classOf[SortMergeJoinExec]),
("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoinExec]),
("SELECT * FROM testData JOIN testData2", classOf[CartesianProductExec]),
("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProductExec]),
@@ -99,7 +99,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
classOf[BroadcastNestedLoopJoinExec]),
("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
classOf[BroadcastNestedLoopJoinExec]),
- ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoinExec]),
+ ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[SortMergeJoinExec]),
("SELECT * FROM testData LEFT ANTI JOIN testData2", classOf[BroadcastNestedLoopJoinExec])
).foreach(assertJoin)
}
@@ -144,7 +144,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
val x = testData2.as("x")
val y = testData2.as("y")
val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan
- val planned = sqlContext.sessionState.planner.EquiJoinSelection(join)
+ val planned = sqlContext.sessionState.planner.JoinSelection(join)
assert(planned.size === 1)
}
@@ -449,9 +449,9 @@ class JoinSuite extends QueryTest with SharedSQLContext {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
- classOf[ShuffledHashJoinExec]),
+ classOf[SortMergeJoinExec]),
("SELECT * FROM testData LEFT ANTI JOIN testData2 ON key = a",
- classOf[ShuffledHashJoinExec])
+ classOf[SortMergeJoinExec])
).foreach(assertJoin)
}
@@ -475,7 +475,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
- classOf[ShuffledHashJoinExec]),
+ classOf[SortMergeJoinExec]),
("SELECT * FROM testData LEFT SEMI JOIN testData2",
classOf[BroadcastNestedLoopJoinExec]),
("SELECT * FROM testData JOIN testData2",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index bc838ee4da..c7c10abe9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -104,6 +104,18 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
}
}
+ test(s"$testName using SortMergeJoin") {
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)),
+ expectedAnswer,
+ sortAnswers = true)
+ }
+ }
+ }
+
test(s"$testName using BroadcastNestedLoopJoin build left") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 933f32e496..2a4a3690f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -189,7 +189,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
test(s"$testName using CartesianProduct") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- FilterExec(condition(), CartesianProductExec(left, right)),
+ CartesianProductExec(left, right, Some(condition())),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 695b1824e8..1859c6e7ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -255,28 +255,14 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value")
// Assume the execution plan is
- // ... -> BroadcastLeftSemiJoinHash(nodeId = 0)
+ // ... -> BroadcastHashJoin(nodeId = 0)
val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi")
testSparkPlanMetrics(df, 2, Map(
- 0L -> ("BroadcastLeftSemiJoinHash", Map(
+ 0L -> ("BroadcastHashJoin", Map(
"number of output rows" -> 2L)))
)
}
- test("ShuffledHashJoin metrics") {
- withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
- val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
- val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value")
- // Assume the execution plan is
- // ... -> ShuffledHashJoin(nodeId = 0)
- val df = df1.join(df2, $"key" === $"key2", "leftsemi")
- testSparkPlanMetrics(df, 1, Map(
- 0L -> ("ShuffledHashJoin", Map(
- "number of output rows" -> 2L)))
- )
- }
- }
-
test("CartesianProduct metrics") {
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index 4a8978e553..9633f9e15b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -120,12 +120,8 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
DataSinks,
Scripts,
Aggregation,
- ExistenceJoin,
- EquiJoinSelection,
- BasicOperators,
- BroadcastNestedLoop,
- CartesianProduct,
- DefaultJoin
+ JoinSelection,
+ BasicOperators
)
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 93a6f0bb58..f6b5101498 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -228,10 +228,10 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton {
assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")
val shj = df.queryExecution.sparkPlan.collect {
- case j: ShuffledHashJoinExec => j
+ case j: SortMergeJoinExec => j
}
assert(shj.size === 1,
- "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off")
+ "SortMergeJoinExec should be planned when BroadcastHashJoin is turned off")
sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp")
}