-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala                                |  10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala                     |  13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala                                                      |   8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                                                   |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala                                           | 148
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala                                          |   6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                                    |  11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                                     |  10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala                                | 169
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala                                                    |  28
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala     | 162
11 files changed, 534 insertions(+), 33 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 1b62e17ff4..b6ec7d3417 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.types.{UTF8String, StructType, NativeType}
-
+import org.apache.spark.sql.types.{UTF8String, DataType, StructType, NativeType}
/**
* An extended interface to [[Row]] that allows the values for each column to be updated. Setting
@@ -239,3 +238,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
return 0
}
}
+
+object RowOrdering {
+ def forSchema(dataTypes: Seq[DataType]): RowOrdering =
+ new RowOrdering(dataTypes.zipWithIndex.map {
+ case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+ })
+}
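The new RowOrdering.forSchema helper builds an ascending, column-by-column ordering from a list of data types; later in this patch SortMergeJoin uses it to compare join keys from either side. Below is a standalone sketch (plain Scala, not part of the commit; all names are illustrative) that mirrors the idea with type-erased per-column orderings.

object RowOrderingSketch {
  type Row = IndexedSeq[Any]

  // Build a lexicographic ordering over rows from per-column orderings,
  // analogous to RowOrdering.forSchema wrapping each column in an ascending
  // SortOrder over a BoundReference.
  def forSchema(columnOrderings: Seq[Ordering[Any]]): Ordering[Row] = new Ordering[Row] {
    def compare(a: Row, b: Row): Int = {
      var i = 0
      while (i < columnOrderings.length) {
        val c = columnOrderings(i).compare(a(i), b(i))
        if (c != 0) return c
        i += 1
      }
      0
    }
  }

  def main(args: Array[String]): Unit = {
    val ord = forSchema(Seq(
      Ordering.Int.asInstanceOf[Ordering[Any]],
      Ordering.String.asInstanceOf[Ordering[Any]]))
    val rows = Seq[Row](Vector(2, "b"), Vector(1, "z"), Vector(1, "a"))
    // Sorted by column 0, then column 1: (1, a), (1, z), (2, b)
    println(rows.sorted(ord))
  }
}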
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 288c11f69f..fb4217a448 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -94,6 +94,9 @@ sealed trait Partitioning {
* only compatible if the `numPartitions` of them is the same.
*/
def compatibleWith(other: Partitioning): Boolean
+
+ /** Returns the expressions that are used to key the partitioning. */
+ def keyExpressions: Seq[Expression]
}
case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
case UnknownPartitioning(_) => true
case _ => false
}
+
+ override def keyExpressions: Seq[Expression] = Nil
}
case object SinglePartition extends Partitioning {
@@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning {
case SinglePartition => true
case _ => false
}
+
+ override def keyExpressions: Seq[Expression] = Nil
}
case object BroadcastPartitioning extends Partitioning {
@@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning {
case SinglePartition => true
case _ => false
}
+
+ override def keyExpressions: Seq[Expression] = Nil
}
/**
@@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}
+ override def keyExpressions: Seq[Expression] = expressions
+
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
@@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}
+ override def keyExpressions: Seq[Expression] = ordering.map(_.child)
+
override def eval(input: Row): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index ee641bdfeb..5c65f04ee8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -47,6 +47,7 @@ private[spark] object SQLConf {
// Options that control which operators can be chosen by the query planner. These should be
// considered hints and may be ignored by future versions of Spark SQL.
val EXTERNAL_SORT = "spark.sql.planner.externalSort"
+ val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin"
// This is only used for the thriftserver
val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
@@ -129,6 +130,13 @@ private[sql] class SQLConf extends Serializable {
private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean
/**
+   * Sort merge join sorts both sides of the join first, and then iterates over both sides
+   * together only once to collect all matches. Using sort merge join can save a lot of memory
+   * compared to HashJoin.
+ */
+ private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean
+
+ /**
* When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
* that evaluates expressions found in queries. In general this custom code runs much faster
* than interpreted evaluation, but there are significant start-up costs due to compilation.
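The new flag is read through SQLConf, so it can be toggled per SQLContext at runtime (the test changes further below flip it the same way). A hedged usage sketch, assuming the Spark 1.x-era APIs this patch targets and a locally created context:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object SortMergeJoinToggle {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("smj-demo").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    // Opt in to sort merge join for equi-joins; the default is "false".
    sqlContext.setConf("spark.sql.planner.sortMergeJoin", "true")
    // ... register tables and run an equi-join query here; the planner will now
    // prefer SortMergeJoin over the hash-based joins for matching key patterns ...
    sc.stop()
  }
}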
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 89a4faf35e..f9f3eb2e03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1081,7 +1081,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches =
- Batch("Add exchange", Once, AddExchange(self)) :: Nil
+ Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil
}
protected[sql] def openSession(): SQLSession = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 437408d30b..518fc9e57c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -19,24 +19,42 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.MutablePair
+object Exchange {
+ /**
+ * Returns true when the ordering expressions are a subset of the key.
+   * If true, ShuffledRDD can use `setKeyOrdering(keyOrdering)` to sort within [[Exchange]].
+ */
+ def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
+ desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
+ }
+}
+
/**
* :: DeveloperApi ::
+ * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
+ * resulting partition based on expressions from the partition key. It is invalid to construct an
+ * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key.
*/
@DeveloperApi
-case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
+case class Exchange(
+ newPartitioning: Partitioning,
+ newOrdering: Seq[SortOrder],
+ child: SparkPlan)
+ extends UnaryNode {
override def outputPartitioning: Partitioning = newPartitioning
+ override def outputOrdering: Seq[SortOrder] = newOrdering
+
override def output: Seq[Attribute] = child.output
/** We must copy rows when sort based shuffle is on */
@@ -45,6 +63,20 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
private val bypassMergeThreshold =
child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+ private val keyOrdering = {
+ if (newOrdering.nonEmpty) {
+ val key = newPartitioning.keyExpressions
+ val boundOrdering = newOrdering.map { o =>
+ val ordinal = key.indexOf(o.child)
+ if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning")
+ o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable))
+ }
+ new RowOrdering(boundOrdering)
+ } else {
+ null // Ordering will not be used
+ }
+ }
+
override def execute(): RDD[Row] = attachTree(this , "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
@@ -56,7 +88,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
// we can avoid the defensive copies to improve performance. In the long run, we probably
// want to include information in shuffle dependencies to indicate whether elements in the
// source RDD should be copied.
- val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
+ val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold
+
+ val rdd = if (willMergeSort || newOrdering.nonEmpty) {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -69,12 +103,17 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
}
}
val part = new HashPartitioner(numPartitions)
- val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
+ val shuffled =
+ if (newOrdering.nonEmpty) {
+ new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering)
+ } else {
+ new ShuffledRDD[Row, Row, Row](rdd, part)
+ }
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)
case RangePartitioning(sortingExpressions, numPartitions) =>
- val rdd = if (sortBasedShuffleOn) {
+ val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) {
child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
} else {
child.execute().mapPartitions { iter =>
@@ -87,7 +126,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
implicit val ordering = new RowOrdering(sortingExpressions, child.output)
val part = new RangePartitioner(numPartitions, rdd, ascending = true)
- val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
+ val shuffled =
+ if (newOrdering.nonEmpty) {
+ new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering)
+ } else {
+ new ShuffledRDD[Row, Null, Null](rdd, part)
+ }
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._1)
@@ -120,27 +164,34 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
* Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
* of input data meets the
* [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
- * each operator by inserting [[Exchange]] Operators where required.
+ * each operator by inserting [[Exchange]] Operators where required. Also ensures that each
+ * operator's required input partition ordering is satisfied.
*/
-private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
+private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
// TODO: Determine the number of partitions.
def numPartitions: Int = sqlContext.conf.numShufflePartitions
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator: SparkPlan =>
- // Check if every child's outputPartitioning satisfies the corresponding
+ // True iff every child's outputPartitioning satisfies the corresponding
// required data distribution.
def meetsRequirements: Boolean =
- !operator.requiredChildDistribution.zip(operator.children).map {
+ operator.requiredChildDistribution.zip(operator.children).forall {
case (required, child) =>
val valid = child.outputPartitioning.satisfies(required)
logDebug(
s"${if (valid) "Valid" else "Invalid"} distribution," +
s"required: $required current: ${child.outputPartitioning}")
valid
- }.exists(!_)
+ }
- // Check if outputPartitionings of children are compatible with each other.
+ // True iff any of the children are incorrectly sorted.
+ def needsAnySort: Boolean =
+ operator.requiredChildOrdering.zip(operator.children).exists {
+ case (required, child) => required.nonEmpty && required != child.outputOrdering
+ }
+
+ // True iff outputPartitionings of children are compatible with each other.
// It is possible that every child satisfies its required data distribution
// but two children have incompatible outputPartitionings. For example,
// A dataset is range partitioned by "a.asc" (RangePartitioning) and another
@@ -157,28 +208,69 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
case Seq(a,b) => a compatibleWith b
}.exists(!_)
- // Check if the partitioning we want to ensure is the same as the child's output
- // partitioning. If so, we do not need to add the Exchange operator.
- def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan =
- if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child
+ // Adds Exchange or Sort operators as required
+ def addOperatorsIfNecessary(
+ partitioning: Partitioning,
+ rowOrdering: Seq[SortOrder],
+ child: SparkPlan): SparkPlan = {
+ val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
+ val needsShuffle = child.outputPartitioning != partitioning
+ val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering)
+
+ if (needSort && needsShuffle && canSortWithShuffle) {
+ Exchange(partitioning, rowOrdering, child)
+ } else {
+ val withShuffle = if (needsShuffle) {
+ Exchange(partitioning, Nil, child)
+ } else {
+ child
+ }
- if (meetsRequirements && compatible) {
+ val withSort = if (needSort) {
+ if (sqlContext.conf.externalSortEnabled) {
+ ExternalSort(rowOrdering, global = false, withShuffle)
+ } else {
+ Sort(rowOrdering, global = false, withShuffle)
+ }
+ } else {
+ withShuffle
+ }
+
+ withSort
+ }
+ }
+
+ if (meetsRequirements && compatible && !needsAnySort) {
operator
} else {
// At least one child does not satisfy its required data distribution or
// at least one child's outputPartitioning is not compatible with another child's
// outputPartitioning. In this case, we need to add Exchange operators.
- val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map {
- case (AllTuples, child) =>
- addExchangeIfNecessary(SinglePartition, child)
- case (ClusteredDistribution(clustering), child) =>
- addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
- case (OrderedDistribution(ordering), child) =>
- addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
- case (UnspecifiedDistribution, child) => child
- case (dist, _) => sys.error(s"Don't know how to ensure $dist")
+ val requirements =
+ (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
+
+ val fixedChildren = requirements.zipped.map {
+ case (AllTuples, rowOrdering, child) =>
+ addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
+ case (ClusteredDistribution(clustering), rowOrdering, child) =>
+ addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
+ case (OrderedDistribution(ordering), rowOrdering, child) =>
+ addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
+
+ case (UnspecifiedDistribution, Seq(), child) =>
+ child
+ case (UnspecifiedDistribution, rowOrdering, child) =>
+ if (sqlContext.conf.externalSortEnabled) {
+ ExternalSort(rowOrdering, global = false, child)
+ } else {
+ Sort(rowOrdering, global = false, child)
+ }
+
+ case (dist, ordering, _) =>
+ sys.error(s"Don't know how to ensure $dist with ordering $ordering")
}
- operator.withNewChildren(repartitionedChildren)
+
+ operator.withNewChildren(fixedChildren)
}
}
}
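EnsureRequirements therefore decides, per child, among four outcomes: do nothing, shuffle only, shuffle and let the shuffle itself sort (when the ordering keys are a subset of the partitioning keys), or sort separately after an optional shuffle. A standalone sketch of that decision table (illustrative names, not the committed code):

object EnsureRequirementsSketch {
  sealed trait Fix
  case object NoChange extends Fix
  case object SortingShuffle extends Fix   // Exchange(partitioning, rowOrdering, child)
  case object ShuffleThenSort extends Fix  // Exchange(partitioning, Nil, child) wrapped in (External)Sort
  case object ShuffleOnly extends Fix      // Exchange(partitioning, Nil, child)
  case object SortOnly extends Fix         // Sort / ExternalSort directly over the child

  def plan(needsShuffle: Boolean, needsSort: Boolean, canSortWithShuffle: Boolean): Fix =
    (needsShuffle, needsSort) match {
      case (true, true) if canSortWithShuffle => SortingShuffle
      case (true, true)                       => ShuffleThenSort
      case (true, false)                      => ShuffleOnly
      case (false, true)                      => SortOnly
      case (false, false)                     => NoChange
    }

  def main(args: Array[String]): Unit = {
    // Ordering keys are a subset of the hash partitioning keys, so the shuffle
    // itself can deliver sorted partitions via setKeyOrdering.
    println(plan(needsShuffle = true, needsSort = true, canSortWithShuffle = true)) // SortingShuffle
  }
}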
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index fabcf6b4a0..e159ffe66c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -72,6 +72,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
def requiredChildDistribution: Seq[Distribution] =
Seq.fill(children.size)(UnspecifiedDistribution)
+ /** Specifies how data is ordered in each partition. */
+ def outputOrdering: Seq[SortOrder] = Nil
+
+  /** Specifies the sort order required for each partition of this operator's input data. */
+ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
+
/**
* Runs this query returning the result as an RDD.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5b99e40c2f..e687d01f57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -90,6 +90,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
+      // If the sort merge join option is set, prefer sort merge join over hash join.
+      // For now only inner joins are supported; outer joins will be added later.
+ case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
+ if sqlContext.conf.sortMergeJoinEnabled =>
+ val mergeJoin =
+ joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
+ condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
val buildSide =
if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
@@ -309,7 +317,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.OneRowRelation =>
execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
case logical.Repartition(expressions, child) =>
- execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
+ execution.Exchange(
+ HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil
case e @ EvaluatePython(udf, child, _) =>
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index f8221f41bc..308dae236a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -41,6 +41,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
val resuableProjection = buildProjection()
iter.map(resuableProjection)
}
+
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
}
/**
@@ -55,6 +57,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
override def execute(): RDD[Row] = child.execute().mapPartitions { iter =>
iter.filter(conditionEvaluator)
}
+
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
}
/**
@@ -147,6 +151,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1)
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
}
/**
@@ -172,6 +178,8 @@ case class Sort(
}
override def output: Seq[Attribute] = child.output
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
}
/**
@@ -202,6 +210,8 @@ case class ExternalSort(
}
override def output: Seq[Attribute] = child.output
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
new file mode 100644
index 0000000000..b5123668ba
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import java.util.NoSuchElementException
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.collection.CompactBuffer
+
+/**
+ * :: DeveloperApi ::
+ * Performs a sort merge join of two child relations.
+ */
+@DeveloperApi
+case class SortMergeJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode {
+
+ override def output: Seq[Attribute] = left.output ++ right.output
+
+ override def outputPartitioning: Partitioning = left.outputPartitioning
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  // Manually construct an ordering that can be used to compare keys from both sides.
+ private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType))
+
+ override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
+
+ @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
+ @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
+
+ private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] =
+ keys.map(SortOrder(_, Ascending))
+
+ override def execute(): RDD[Row] = {
+ val leftResults = left.execute().map(_.copy())
+ val rightResults = right.execute().map(_.copy())
+
+ leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
+ new Iterator[Row] {
+ // Mutable per row objects.
+ private[this] val joinRow = new JoinedRow5
+ private[this] var leftElement: Row = _
+ private[this] var rightElement: Row = _
+ private[this] var leftKey: Row = _
+ private[this] var rightKey: Row = _
+ private[this] var rightMatches: CompactBuffer[Row] = _
+ private[this] var rightPosition: Int = -1
+ private[this] var stop: Boolean = false
+ private[this] var matchKey: Row = _
+
+ // initialize iterator
+ initialize()
+
+ override final def hasNext: Boolean = nextMatchingPair()
+
+ override final def next(): Row = {
+ if (hasNext) {
+          // reuse the buffered right rows while running down the left iterator
+ val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
+ rightPosition += 1
+ if (rightPosition >= rightMatches.size) {
+ rightPosition = 0
+ fetchLeft()
+ if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) {
+ stop = false
+ rightMatches = null
+ }
+ }
+ joinedRow
+ } else {
+          // no more results
+ throw new NoSuchElementException
+ }
+ }
+
+ private def fetchLeft() = {
+ if (leftIter.hasNext) {
+ leftElement = leftIter.next()
+ leftKey = leftKeyGenerator(leftElement)
+ } else {
+ leftElement = null
+ }
+ }
+
+ private def fetchRight() = {
+ if (rightIter.hasNext) {
+ rightElement = rightIter.next()
+ rightKey = rightKeyGenerator(rightElement)
+ } else {
+ rightElement = null
+ }
+ }
+
+ private def initialize() = {
+ fetchLeft()
+ fetchRight()
+ }
+
+ /**
+       * Searches the right iterator for the next rows that have matches on the left side, and
+       * stores them in a buffer.
+ *
+       * @return true if there is at least one matching pair to emit, and false when no more
+       *         matches remain.
+ */
+ private def nextMatchingPair(): Boolean = {
+ if (!stop && rightElement != null) {
+            // advance both sides until the first matching pair is found
+ while (!stop && leftElement != null && rightElement != null) {
+ val comparing = keyOrdering.compare(leftKey, rightKey)
+              // for inner joins, rows with null keys must be filtered out
+ stop = comparing == 0 && !leftKey.anyNull
+ if (comparing > 0 || rightKey.anyNull) {
+ fetchRight()
+ } else if (comparing < 0 || leftKey.anyNull) {
+ fetchLeft()
+ }
+ }
+ rightMatches = new CompactBuffer[Row]()
+ if (stop) {
+ stop = false
+            // iterate over the right side to buffer all rows that match;
+            // since the records are ordered, exit at the first row that does not match
+ while (!stop && rightElement != null) {
+ rightMatches += rightElement
+ fetchRight()
+ stop = keyOrdering.compare(leftKey, rightKey) != 0
+ }
+ if (rightMatches.size > 0) {
+ rightPosition = 0
+ matchKey = leftKey
+ }
+ }
+ }
+ rightMatches != null && rightMatches.size > 0
+ }
+ }
+ }
+ }
+}
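The iterator above interleaves fetching and buffering so it never materializes more than one key group of right rows at a time. Stripped of Spark's Row and projection machinery, the same inner-join logic over two key-sorted inputs looks like the following standalone sketch (illustrative types, not the committed code):

object SortMergeJoinSketch {
  // Inner-joins two inputs already sorted by key. Int keys are used purely for
  // illustration; the committed iterator compares full key Rows with keyOrdering.
  def mergeJoin[L, R](left: Seq[(Int, L)], right: Seq[(Int, R)]): Seq[(L, R)] = {
    val out = Seq.newBuilder[(L, R)]
    var i = 0
    var j = 0
    while (i < left.length && j < right.length) {
      val c = left(i)._1 compare right(j)._1
      if (c < 0) i += 1        // left key too small: advance left (fetchLeft)
      else if (c > 0) j += 1   // right key too small: advance right (fetchRight)
      else {
        // Buffer the run of right rows sharing this key (the rightMatches buffer),
        // then replay the buffer for every left row with the same key.
        val key = left(i)._1
        val start = j
        while (j < right.length && right(j)._1 == key) j += 1
        val matches = right.slice(start, j)
        while (i < left.length && left(i)._1 == key) {
          matches.foreach { case (_, r) => out += ((left(i)._2, r)) }
          i += 1
        }
      }
    }
    out.result()
  }

  def main(args: Array[String]): Unit = {
    val l = Seq(1 -> "a", 2 -> "b", 2 -> "c", 4 -> "d")
    val r = Seq(2 -> "x", 2 -> "y", 3 -> "z")
    println(mergeJoin(l, r)) // List((b,x), (b,y), (c,x), (c,y))
  }
}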
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index e4dee87849..037d392c1f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -51,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
case j: BroadcastLeftSemiJoinHash => j
+ case j: SortMergeJoin => j
}
assert(operators.size === 1)
@@ -62,6 +63,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("join operator selection") {
cacheManager.clearCache()
+ val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
@@ -91,17 +93,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
classOf[BroadcastNestedLoopJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ try {
+ conf.setConf("spark.sql.planner.sortMergeJoin", "true")
+ Seq(
+ ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin])
+ ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ } finally {
+ conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString)
+ }
}
test("broadcasted hash join operator selection") {
cacheManager.clearCache()
sql("CACHE TABLE testData")
+ val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled
Seq(
("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]),
- ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin])
+ ("SELECT * FROM testData join testData2 ON key = a where key = 2",
+ classOf[BroadcastHashJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ try {
+ conf.setConf("spark.sql.planner.sortMergeJoin", "true")
+ Seq(
+ ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
+ ("SELECT * FROM testData join testData2 ON key = a and key = 2",
+ classOf[BroadcastHashJoin]),
+ ("SELECT * FROM testData join testData2 ON key = a where key = 2",
+ classOf[BroadcastHashJoin])
+ ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ } finally {
+ conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString)
+ }
sql("UNCACHE TABLE testData")
}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
new file mode 100644
index 0000000000..65d070bd3c
--- /dev/null
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.hive.test.TestHive
+
+/**
+ * Runs the test cases that are included in the Hive distribution with sort merge join enabled.
+ */
+class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
+ override def beforeAll() {
+ super.beforeAll()
+ TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true")
+ }
+
+ override def afterAll() {
+ TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false")
+ super.afterAll()
+ }
+
+ override def whiteList = Seq(
+ "auto_join0",
+ "auto_join1",
+ "auto_join10",
+ "auto_join11",
+ "auto_join12",
+ "auto_join13",
+ "auto_join14",
+ "auto_join14_hadoop20",
+ "auto_join15",
+ "auto_join17",
+ "auto_join18",
+ "auto_join19",
+ "auto_join2",
+ "auto_join20",
+ "auto_join21",
+ "auto_join22",
+ "auto_join23",
+ "auto_join24",
+ "auto_join25",
+ "auto_join26",
+ "auto_join27",
+ "auto_join28",
+ "auto_join3",
+ "auto_join30",
+ "auto_join31",
+ "auto_join32",
+ "auto_join4",
+ "auto_join5",
+ "auto_join6",
+ "auto_join7",
+ "auto_join8",
+ "auto_join9",
+ "auto_join_filters",
+ "auto_join_nulls",
+ "auto_join_reordering_values",
+ "auto_smb_mapjoin_14",
+ "auto_sortmerge_join_1",
+ "auto_sortmerge_join_10",
+ "auto_sortmerge_join_11",
+ "auto_sortmerge_join_12",
+ "auto_sortmerge_join_13",
+ "auto_sortmerge_join_14",
+ "auto_sortmerge_join_15",
+ "auto_sortmerge_join_16",
+ "auto_sortmerge_join_2",
+ "auto_sortmerge_join_3",
+ "auto_sortmerge_join_4",
+ "auto_sortmerge_join_5",
+ "auto_sortmerge_join_6",
+ "auto_sortmerge_join_7",
+ "auto_sortmerge_join_8",
+ "auto_sortmerge_join_9",
+ "correlationoptimizer1",
+ "correlationoptimizer10",
+ "correlationoptimizer11",
+ "correlationoptimizer13",
+ "correlationoptimizer14",
+ "correlationoptimizer15",
+ "correlationoptimizer2",
+ "correlationoptimizer3",
+ "correlationoptimizer4",
+ "correlationoptimizer6",
+ "correlationoptimizer7",
+ "correlationoptimizer8",
+ "correlationoptimizer9",
+ "join0",
+ "join1",
+ "join10",
+ "join11",
+ "join12",
+ "join13",
+ "join14",
+ "join14_hadoop20",
+ "join15",
+ "join16",
+ "join17",
+ "join18",
+ "join19",
+ "join2",
+ "join20",
+ "join21",
+ "join22",
+ "join23",
+ "join24",
+ "join25",
+ "join26",
+ "join27",
+ "join28",
+ "join29",
+ "join3",
+ "join30",
+ "join31",
+ "join32",
+ "join32_lessSize",
+ "join33",
+ "join34",
+ "join35",
+ "join36",
+ "join37",
+ "join38",
+ "join39",
+ "join4",
+ "join40",
+ "join41",
+ "join5",
+ "join6",
+ "join7",
+ "join8",
+ "join9",
+ "join_1to1",
+ "join_array",
+ "join_casesensitive",
+ "join_empty",
+ "join_filters",
+ "join_hive_626",
+ "join_map_ppr",
+ "join_nulls",
+ "join_nullsafe",
+ "join_rc",
+ "join_reorder2",
+ "join_reorder3",
+ "join_reorder4",
+ "join_star"
+ )
+}