Diffstat (limited to 'sql/core')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 66
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala | 58
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala | 95
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 56
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 7
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 8
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala | 22
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 4
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 38
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala | 16
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala | 5
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala | 18
12 files changed, 276 insertions, 117 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 121b6d9e97..de4b4b799f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -29,7 +28,8 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescribeCommand, _}
-import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
+import org.apache.spark.sql.execution.datasources.{DescribeCommand => LogicalDescribeCommand, _}
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
import org.apache.spark.sql.internal.SQLConf
@@ -69,8 +69,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
// Find left semi joins where at least some predicates can be evaluated by matching join keys
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
- joins.LeftSemiJoinHash(
- leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
+ Seq(joins.ShuffledHashJoin(
+ leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
case _ => Nil
}
}
@@ -80,8 +80,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object CanBroadcast {
def unapply(plan: LogicalPlan): Option[LogicalPlan] = {
- if (conf.autoBroadcastJoinThreshold > 0 &&
- plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
+ if (plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
Some(plan)
} else {
None
@@ -101,10 +100,41 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side
* of the join will be broadcasted and the other side will be streamed, with no shuffling
* performed. If both sides of the join are eligible to be broadcasted then the
+ * - Shuffle hash join: if a single partition is small enough to build a hash table.
* - Sort merge: if the matching join keys are sortable.
*/
object EquiJoinSelection extends Strategy with PredicateHelper {
+ /**
+ * Matches a plan whose single partition should be small enough to build a hash table.
+ */
+ def canBuildHashMap(plan: LogicalPlan): Boolean = {
+ plan.statistics.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
+ }
+
+ /**
+ * Returns whether plan a is much smaller (3X) than plan b.
+ *
+   * Since the cost of building a hash map is higher than sorting, we should only build a hash map
+   * on a table that is much smaller than the other one. Since we do not have statistics for the
+   * number of rows, use the size in bytes here as the estimate.
+ */
+ private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = {
+ a.statistics.sizeInBytes * 3 <= b.statistics.sizeInBytes
+ }
+
+ /**
+ * Returns whether we should use shuffle hash join or not.
+ *
+ * We should only use shuffle hash join when:
+   * 1) any single partition of the small table can fit in memory, and
+   * 2) the small table is much smaller (3X) than the other one.
+ */
+ private def shouldShuffleHashJoin(left: LogicalPlan, right: LogicalPlan): Boolean = {
+ canBuildHashMap(left) && muchSmaller(left, right) ||
+ canBuildHashMap(right) && muchSmaller(right, left)
+ }
+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// --- Inner joins --------------------------------------------------------------------------
@@ -118,6 +148,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
+ if !conf.preferSortMergeJoin && shouldShuffleHashJoin(left, right) ||
+ !RowOrdering.isOrderable(leftKeys) =>
+ val buildSide =
+ if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
+ BuildRight
+ } else {
+ BuildLeft
+ }
+ Seq(joins.ShuffledHashJoin(
+ leftKeys, rightKeys, Inner, buildSide, condition, planLater(left), planLater(right)))
+
+ case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
joins.SortMergeJoin(
leftKeys, rightKeys, Inner, condition, planLater(left), planLater(right)) :: Nil
@@ -134,6 +176,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
Seq(joins.BroadcastHashJoin(
leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
+ case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
+ if !conf.preferSortMergeJoin && canBuildHashMap(right) && muchSmaller(right, left) ||
+ !RowOrdering.isOrderable(leftKeys) =>
+ Seq(joins.ShuffledHashJoin(
+ leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
+
+ case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
+ if !conf.preferSortMergeJoin && canBuildHashMap(left) && muchSmaller(left, right) ||
+ !RowOrdering.isOrderable(leftKeys) =>
+ Seq(joins.ShuffledHashJoin(
+ leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
+
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
joins.SortMergeJoin(
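Note: the selection heuristic added above can be summarized as a standalone predicate. The sketch below is illustrative only, not planner code; the size and partition-count arguments stand in for LogicalPlan.statistics.sizeInBytes, conf.autoBroadcastJoinThreshold and conf.numShufflePartitions.

    // Minimal sketch of the shuffle-hash-join selection heuristic above (hypothetical inputs).
    object ShuffledHashJoinHeuristic {
      // A side can be hashed if each of its shuffle partitions is expected to fit in memory.
      def canBuildHashMap(
          sizeInBytes: BigInt,
          broadcastThreshold: Long,
          numShufflePartitions: Int): Boolean =
        sizeInBytes < BigInt(broadcastThreshold) * numShufflePartitions

      // Building a hash map costs more than sorting, so require a 3x size gap.
      def muchSmaller(a: BigInt, b: BigInt): Boolean = a * 3 <= b

      // Shuffle hash join is worthwhile only if one side is hashable and much smaller.
      def shouldShuffleHashJoin(
          leftSize: BigInt,
          rightSize: BigInt,
          broadcastThreshold: Long,
          numShufflePartitions: Int): Boolean =
        (canBuildHashMap(leftSize, broadcastThreshold, numShufflePartitions) &&
          muchSmaller(leftSize, rightSize)) ||
        (canBuildHashMap(rightSize, broadcastThreshold, numShufflePartitions) &&
          muchSmaller(rightSize, leftSize))
    }
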
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
deleted file mode 100644
index fa549b4d51..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.LeftSemi
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-/**
- * Build the right table's join keys into a HashedRelation, and iteratively go through the left
- * table, to find if the join keys are in the HashedRelation.
- */
-case class LeftSemiJoinHash(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- left: SparkPlan,
- right: SparkPlan,
- condition: Option[Expression]) extends BinaryNode with HashJoin {
-
- override val joinType = LeftSemi
- override val buildSide = BuildRight
-
- override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
- override def outputPartitioning: Partitioning = left.outputPartitioning
-
- override def requiredChildDistribution: Seq[Distribution] =
- ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
- protected override def doExecute(): RDD[InternalRow] = {
- val numOutputRows = longMetric("numOutputRows")
-
- right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
- val hashRelation = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
- hashSemiJoin(streamIter, hashRelation, numOutputRows)
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
new file mode 100644
index 0000000000..1e8879ae01
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+
+/**
+ * Performs an inner hash join of two child relations by first shuffling the data using the join
+ * keys.
+ */
+case class ShuffledHashJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ buildSide: BuildSide,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan)
+ extends BinaryNode with HashJoin {
+
+ override private[sql] lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
+ override def outputPartitioning: Partitioning = joinType match {
+ case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+ case LeftSemi => left.outputPartitioning
+ case LeftOuter => left.outputPartitioning
+ case RightOuter => right.outputPartitioning
+ case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+ case x =>
+ throw new IllegalArgumentException(s"ShuffledHashJoin should not take $x as the JoinType")
+ }
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val numOutputRows = longMetric("numOutputRows")
+
+ streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
+ val hashed = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
+ val joinedRow = new JoinedRow
+ joinType match {
+ case Inner =>
+ hashJoin(streamIter, hashed, numOutputRows)
+
+ case LeftSemi =>
+ hashSemiJoin(streamIter, hashed, numOutputRows)
+
+ case LeftOuter =>
+ val keyGenerator = streamSideKeyGenerator
+ val resultProj = createResultProjection
+ streamIter.flatMap(currentRow => {
+ val rowKey = keyGenerator(currentRow)
+ joinedRow.withLeft(currentRow)
+ leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows)
+ })
+
+ case RightOuter =>
+ val keyGenerator = streamSideKeyGenerator
+ val resultProj = createResultProjection
+ streamIter.flatMap(currentRow => {
+ val rowKey = keyGenerator(currentRow)
+ joinedRow.withRight(currentRow)
+ rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows)
+ })
+
+ case x =>
+ throw new IllegalArgumentException(
+ s"ShuffledHashJoin should not take $x as the JoinType")
+ }
+ }
+ }
+}
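Note: as a rough illustration of the zipPartitions pattern in doExecute above (build a hash table from one partition of the build side, then probe it with the streamed side), here is a simplified sketch over plain Scala collections. The generic key/row types are illustrative; this is not Spark's HashedRelation code path.

    // Simplified per-partition inner hash join: build on one side, probe with the other.
    def innerHashJoinPartition[K, R](
        streamIter: Iterator[R],
        buildIter: Iterator[R],
        streamKey: R => K,
        buildKey: R => K): Iterator[(R, R)] = {
      // Build: group the (typically smaller) build side by join key.
      val hashed: Map[K, Seq[R]] = buildIter.toSeq.groupBy(buildKey)
      // Probe: stream the other side and emit matching pairs.
      streamIter.flatMap { row =>
        hashed.getOrElse(streamKey(row), Seq.empty).map(matched => (row, matched))
      }
    }
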
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 807b39ace6..60bd8ea39a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -665,11 +665,11 @@ private[joins] class SortMergeJoinScanner(
* An iterator for outputting rows in left outer join.
*/
private class LeftOuterIterator(
- smjScanner: SortMergeJoinScanner,
- rightNullRow: InternalRow,
- boundCondition: InternalRow => Boolean,
- resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric)
+ smjScanner: SortMergeJoinScanner,
+ rightNullRow: InternalRow,
+ boundCondition: InternalRow => Boolean,
+ resultProj: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric)
extends OneSideOuterIterator(
smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
@@ -681,13 +681,12 @@ private class LeftOuterIterator(
* An iterator for outputting rows in right outer join.
*/
private class RightOuterIterator(
- smjScanner: SortMergeJoinScanner,
- leftNullRow: InternalRow,
- boundCondition: InternalRow => Boolean,
- resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric)
- extends OneSideOuterIterator(
- smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
+ smjScanner: SortMergeJoinScanner,
+ leftNullRow: InternalRow,
+ boundCondition: InternalRow => Boolean,
+ resultProj: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric)
+ extends OneSideOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
@@ -710,11 +709,11 @@ private class RightOuterIterator(
* @param numOutputRows an accumulator metric for the number of rows output
*/
private abstract class OneSideOuterIterator(
- smjScanner: SortMergeJoinScanner,
- bufferedSideNullRow: InternalRow,
- boundCondition: InternalRow => Boolean,
- resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric) extends RowIterator {
+ smjScanner: SortMergeJoinScanner,
+ bufferedSideNullRow: InternalRow,
+ boundCondition: InternalRow => Boolean,
+ resultProj: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric) extends RowIterator {
// A row to store the joined result, reused many times
protected[this] val joinedRow: JoinedRow = new JoinedRow()
@@ -777,14 +776,14 @@ private abstract class OneSideOuterIterator(
}
private class SortMergeFullOuterJoinScanner(
- leftKeyGenerator: Projection,
- rightKeyGenerator: Projection,
- keyOrdering: Ordering[InternalRow],
- leftIter: RowIterator,
- rightIter: RowIterator,
- boundCondition: InternalRow => Boolean,
- leftNullRow: InternalRow,
- rightNullRow: InternalRow) {
+ leftKeyGenerator: Projection,
+ rightKeyGenerator: Projection,
+ keyOrdering: Ordering[InternalRow],
+ leftIter: RowIterator,
+ rightIter: RowIterator,
+ boundCondition: InternalRow => Boolean,
+ leftNullRow: InternalRow,
+ rightNullRow: InternalRow) {
private[this] val joinedRow: JoinedRow = new JoinedRow()
private[this] var leftRow: InternalRow = _
private[this] var leftRowKey: InternalRow = _
@@ -950,10 +949,9 @@ private class SortMergeFullOuterJoinScanner(
}
private class FullOuterIterator(
- smjScanner: SortMergeFullOuterJoinScanner,
- resultProj: InternalRow => InternalRow,
- numRows: LongSQLMetric
-) extends RowIterator {
+ smjScanner: SortMergeFullOuterJoinScanner,
+ resultProj: InternalRow => InternalRow,
+ numRows: LongSQLMetric) extends RowIterator {
private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
override def advanceNext(): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 9aabe2d0ab..c308161413 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -236,6 +236,11 @@ object SQLConf {
doc = "When true, enable partition pruning for in-memory columnar tables.",
isPublic = false)
+ val PREFER_SORTMERGEJOIN = booleanConf("spark.sql.join.preferSortMergeJoin",
+ defaultValue = Some(true),
+ doc = "When true, prefer sort merge join over shuffle hash join",
+ isPublic = false)
+
val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold",
defaultValue = Some(10 * 1024 * 1024),
doc = "Configures the maximum size in bytes for a table that will be broadcast to all worker " +
@@ -586,6 +591,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
+ def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)
+
def defaultSizeInBytes: Long =
getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L)
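Note: given the new flag, one plausible way to exercise the shuffled hash join from user code, mirroring the benchmark added later in this diff, is to turn off the sort-merge preference and keep both sides above the broadcast threshold. The sizes and the `sqlContext` setup below are assumptions taken from that benchmark, not a prescribed recipe.

    import org.apache.spark.sql.functions.col

    // Assumed setup: `sqlContext` is an existing SQLContext, as in the benchmark below.
    sqlContext.setConf("spark.sql.join.preferSortMergeJoin", "false")
    // Both sides stay above the broadcast threshold, so no broadcast join is chosen,
    // while a single shuffle partition of the small side can still be hashed.
    sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "10000000")
    sqlContext.setConf("spark.sql.shuffle.partitions", "2")

    val big = sqlContext.range(4 << 20).selectExpr("id as k1")
    val small = sqlContext.range((4 << 20) / 5).selectExpr("id * 3 as k2")
    big.join(small, col("k1") === col("k2")).count()
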
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 03d6df8c28..dfffa4bc8b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -45,8 +45,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
val df = sql(sqlString)
val physical = df.queryExecution.sparkPlan
val operators = physical.collect {
- case j: LeftSemiJoinHash => j
case j: BroadcastHashJoin => j
+ case j: ShuffledHashJoin => j
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
case j: SortMergeJoin => j
@@ -63,7 +63,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
Seq(
- ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
@@ -434,7 +434,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
Seq(
- ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash])
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin])
).foreach {
case (query, joinClass) => assertJoin(query, joinClass)
}
@@ -460,7 +460,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
- classOf[LeftSemiJoinHash]),
+ classOf[ShuffledHashJoin]),
("SELECT * FROM testData LEFT SEMI JOIN testData2",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData JOIN testData2",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index d293ff66fb..a16bd77bfe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -247,7 +247,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
*/
}
- ignore("rube") {
+ ignore("shuffle hash join") {
+ val N = 4 << 20
+ sqlContext.setConf("spark.sql.shuffle.partitions", "2")
+ sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "10000000")
+ sqlContext.setConf("spark.sql.join.preferSortMergeJoin", "false")
+ runBenchmark("shuffle hash join", N) {
+ val df1 = sqlContext.range(N).selectExpr(s"id as k1")
+ val df2 = sqlContext.range(N / 5).selectExpr(s"id * 3 as k2")
+ df1.join(df2, col("k1") === col("k2")).count()
+ }
+
+ /**
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ shuffle hash join codegen=false 1168 / 1902 3.6 278.6 1.0X
+ shuffle hash join codegen=true 850 / 1196 4.9 202.8 1.4X
+ */
+ }
+
+ ignore("cube") {
val N = 5 << 20
runBenchmark("cube", N) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 9cd50abda6..e9b65539b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchange, ReuseExchange, ShuffleExchange}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin, SortMergeJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -143,7 +143,7 @@ class PlannerSuite extends SharedSQLContext {
val sortMergeJoins = planned.collect { case join: SortMergeJoin => join }
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
- assert(sortMergeJoins.isEmpty, "Should not use sort merge join")
+ assert(sortMergeJoins.isEmpty, "Should not use shuffled hash join")
sqlContext.clearCache()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 814e25d10e..cf2681050e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -101,6 +101,20 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin)
}
+ def makeShuffledHashJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ boundCondition: Option[Expression],
+ leftPlan: SparkPlan,
+ rightPlan: SparkPlan,
+ side: BuildSide) = {
+ val shuffledHashJoin =
+ joins.ShuffledHashJoin(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan)
+ val filteredJoin =
+ boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
+ EnsureRequirements(sqlContext.sessionState.conf).apply(filteredJoin)
+ }
+
def makeSortMergeJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
@@ -136,6 +150,30 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
}
}
+ test(s"$testName using ShuffledHashJoin (build=left)") {
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+ makeShuffledHashJoin(
+ leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+ }
+
+ test(s"$testName using ShuffledHashJoin (build=right)") {
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+ makeShuffledHashJoin(
+ leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+ }
+
test(s"$testName using SortMergeJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 1c8b2ea808..4cacb20aa0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -77,6 +77,22 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
}
if (joinType != FullOuter) {
+ test(s"$testName using ShuffledHashJoin") {
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ EnsureRequirements(sqlContext.sessionState.conf).apply(
+ ShuffledHashJoin(
+ leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+ }
+ }
+
+ if (joinType != FullOuter) {
test(s"$testName using BroadcastHashJoin") {
val buildSide = joinType match {
case LeftOuter => BuildRight
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 5eb6a74523..985a96f684 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -72,12 +72,13 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
ExtractEquiJoinKeys.unapply(join)
}
- test(s"$testName using LeftSemiJoinHash") {
+ test(s"$testName using ShuffledHashJoin") {
extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(left.sqlContext.sessionState.conf).apply(
- LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
+ ShuffledHashJoin(
+ leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 988852a4fc..fa68c1a91d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -263,32 +263,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
)
}
- test("LeftSemiJoinHash metrics") {
+ test("ShuffledHashJoin metrics") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value")
// Assume the execution plan is
- // ... -> LeftSemiJoinHash(nodeId = 0)
+ // ... -> ShuffledHashJoin(nodeId = 0)
val df = df1.join(df2, $"key" === $"key2", "leftsemi")
testSparkPlanMetrics(df, 1, Map(
- 0L -> ("LeftSemiJoinHash", Map(
+ 0L -> ("ShuffledHashJoin", Map(
"number of output rows" -> 2L)))
)
}
}
- test("LeftSemiJoinBNL metrics") {
- val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
- val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value")
- // Assume the execution plan is
- // ... -> LeftSemiJoinBNL(nodeId = 0)
- val df = df1.join(df2, $"key" < $"key2", "leftsemi")
- testSparkPlanMetrics(df, 2, Map(
- 0L -> ("LeftSemiJoinBNL", Map(
- "number of output rows" -> 2L)))
- )
- }
-
test("CartesianProduct metrics") {
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")