aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDaoyuan Wang <daoyuan.wang@intel.com>2015-04-15 14:06:10 -0700
committerMichael Armbrust <michael@databricks.com>2015-04-15 14:06:10 -0700
commit585638e81ce09a72b9e7f95d38e0d432cfa02456 (patch)
treeb26fb5702b77fd7124b952be806a4cfc039890ce /sql
parent4754e16f4746ebd882b2ce7f1efc6e4d4408922c (diff)
downloadspark-585638e81ce09a72b9e7f95d38e0d432cfa02456.tar.gz
spark-585638e81ce09a72b9e7f95d38e0d432cfa02456.tar.bz2
spark-585638e81ce09a72b9e7f95d38e0d432cfa02456.zip
[SPARK-2213] [SQL] sort merge join for spark sql
Thanks for the initial work from Ishiihara in #3173 This PR introduce a new join method of sort merge join, which firstly ensure that keys of same value are in the same partition, and inside each partition the Rows are sorted by key. Then we can run down both sides together, find matched rows using [sort merge join](http://en.wikipedia.org/wiki/Sort-merge_join). In this way, we don't have to store the whole hash table of one side as hash join, thus we have less memory usage. Also, this PR would benefit from #3438 , making the sorting phrase much more efficient. We introduced a new configuration of "spark.sql.planner.sortMergeJoin" to switch between this(`true`) and ShuffledHashJoin(`false`), probably we want the default value of it be `false` at first. Author: Daoyuan Wang <daoyuan.wang@intel.com> Author: Michael Armbrust <michael@databricks.com> This patch had conflicts when merged, resolved by Committer: Michael Armbrust <michael@databricks.com> Closes #5208 from adrian-wang/smj and squashes the following commits: 2493b9f [Daoyuan Wang] fix style 5049d88 [Daoyuan Wang] propagate rowOrdering for RangePartitioning f91a2ae [Daoyuan Wang] yin's comment: use external sort if option is enabled, add comments f515cd2 [Daoyuan Wang] yin's comment: outputOrdering, join suite refine ec8061b [Daoyuan Wang] minor change 413fd24 [Daoyuan Wang] Merge pull request #3 from marmbrus/pr/5208 952168a [Michael Armbrust] add type 5492884 [Michael Armbrust] copy when ordering 7ddd656 [Michael Armbrust] Cleanup addition of ordering requirements b198278 [Daoyuan Wang] inherit ordering in project c8e82a3 [Daoyuan Wang] fix style 6e897dd [Daoyuan Wang] hide boundReference from manually construct RowOrdering for key compare in smj 8681d73 [Daoyuan Wang] refactor Exchange and fix copy for sorting 2875ef2 [Daoyuan Wang] fix changed configuration 61d7f49 [Daoyuan Wang] add omitted comment 00a4430 [Daoyuan Wang] fix bug 078d69b [Daoyuan Wang] address comments: add comments, do sort in shuffle, and others 3af6ba5 [Daoyuan Wang] use buffer for only one side 171001f [Daoyuan Wang] change default outputordering 47455c9 [Daoyuan Wang] add apache license ... a28277f [Daoyuan Wang] fix style 645c70b [Daoyuan Wang] address comments using sort 068c35d [Daoyuan Wang] fix new style and add some tests 925203b [Daoyuan Wang] address comments 07ce92f [Daoyuan Wang] fix ArrayIndexOutOfBound 42fca0e [Daoyuan Wang] code clean e3ec096 [Daoyuan Wang] fix comment style.. 2edd235 [Daoyuan Wang] fix outputpartitioning 57baa40 [Daoyuan Wang] fix sort eval bug 303b6da [Daoyuan Wang] fix several errors 95db7ad [Daoyuan Wang] fix brackets for if-statement 4464f16 [Daoyuan Wang] fix error 880d8e9 [Daoyuan Wang] sort merge join for spark sql
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala13
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala148
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala11
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala10
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala169
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala28
-rw-r--r--sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala162
11 files changed, 534 insertions, 33 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 1b62e17ff4..b6ec7d3417 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.types.{UTF8String, StructType, NativeType}
-
+import org.apache.spark.sql.types.{UTF8String, DataType, StructType, NativeType}
/**
* An extended interface to [[Row]] that allows the values for each column to be updated. Setting
@@ -239,3 +238,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
return 0
}
}
+
+object RowOrdering {
+ def forSchema(dataTypes: Seq[DataType]): RowOrdering =
+ new RowOrdering(dataTypes.zipWithIndex.map {
+ case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+ })
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 288c11f69f..fb4217a448 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -94,6 +94,9 @@ sealed trait Partitioning {
* only compatible if the `numPartitions` of them is the same.
*/
def compatibleWith(other: Partitioning): Boolean
+
+ /** Returns the expressions that are used to key the partitioning. */
+ def keyExpressions: Seq[Expression]
}
case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
case UnknownPartitioning(_) => true
case _ => false
}
+
+ override def keyExpressions: Seq[Expression] = Nil
}
case object SinglePartition extends Partitioning {
@@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning {
case SinglePartition => true
case _ => false
}
+
+ override def keyExpressions: Seq[Expression] = Nil
}
case object BroadcastPartitioning extends Partitioning {
@@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning {
case SinglePartition => true
case _ => false
}
+
+ override def keyExpressions: Seq[Expression] = Nil
}
/**
@@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}
+ override def keyExpressions: Seq[Expression] = expressions
+
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
@@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}
+ override def keyExpressions: Seq[Expression] = ordering.map(_.child)
+
override def eval(input: Row): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index ee641bdfeb..5c65f04ee8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -47,6 +47,7 @@ private[spark] object SQLConf {
// Options that control which operators can be chosen by the query planner. These should be
// considered hints and may be ignored by future versions of Spark SQL.
val EXTERNAL_SORT = "spark.sql.planner.externalSort"
+ val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin"
// This is only used for the thriftserver
val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
@@ -129,6 +130,13 @@ private[sql] class SQLConf extends Serializable {
private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean
/**
+ * Sort merge join would sort the two side of join first, and then iterate both sides together
+ * only once to get all matches. Using sort merge join can save a lot of memory usage compared
+ * to HashJoin.
+ */
+ private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean
+
+ /**
* When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
* that evaluates expressions found in queries. In general this custom code runs much faster
* than interpreted evaluation, but there are significant start-up costs due to compilation.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 89a4faf35e..f9f3eb2e03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1081,7 +1081,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches =
- Batch("Add exchange", Once, AddExchange(self)) :: Nil
+ Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil
}
protected[sql] def openSession(): SQLSession = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 437408d30b..518fc9e57c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -19,24 +19,42 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.MutablePair
+object Exchange {
+ /**
+ * Returns true when the ordering expressions are a subset of the key.
+ * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]].
+ */
+ def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
+ desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
+ }
+}
+
/**
* :: DeveloperApi ::
+ * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
+ * resulting partition based on expressions from the partition key. It is invalid to construct an
+ * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key.
*/
@DeveloperApi
-case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
+case class Exchange(
+ newPartitioning: Partitioning,
+ newOrdering: Seq[SortOrder],
+ child: SparkPlan)
+ extends UnaryNode {
override def outputPartitioning: Partitioning = newPartitioning
+ override def outputOrdering: Seq[SortOrder] = newOrdering
+
override def output: Seq[Attribute] = child.output
/** We must copy rows when sort based shuffle is on */
@@ -45,6 +63,20 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
private val bypassMergeThreshold =
child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+ private val keyOrdering = {
+ if (newOrdering.nonEmpty) {
+ val key = newPartitioning.keyExpressions
+ val boundOrdering = newOrdering.map { o =>
+ val ordinal = key.indexOf(o.child)
+ if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning")
+ o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable))
+ }
+ new RowOrdering(boundOrdering)
+ } else {
+ null // Ordering will not be used
+ }
+ }
+
override def execute(): RDD[Row] = attachTree(this , "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
@@ -56,7 +88,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
// we can avoid the defensive copies to improve performance. In the long run, we probably
// want to include information in shuffle dependencies to indicate whether elements in the
// source RDD should be copied.
- val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
+ val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold
+
+ val rdd = if (willMergeSort || newOrdering.nonEmpty) {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -69,12 +103,17 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
}
}
val part = new HashPartitioner(numPartitions)
- val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
+ val shuffled =
+ if (newOrdering.nonEmpty) {
+ new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering)
+ } else {
+ new ShuffledRDD[Row, Row, Row](rdd, part)
+ }
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)
case RangePartitioning(sortingExpressions, numPartitions) =>
- val rdd = if (sortBasedShuffleOn) {
+ val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) {
child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
} else {
child.execute().mapPartitions { iter =>
@@ -87,7 +126,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
implicit val ordering = new RowOrdering(sortingExpressions, child.output)
val part = new RangePartitioner(numPartitions, rdd, ascending = true)
- val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
+ val shuffled =
+ if (newOrdering.nonEmpty) {
+ new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering)
+ } else {
+ new ShuffledRDD[Row, Null, Null](rdd, part)
+ }
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._1)
@@ -120,27 +164,34 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
* Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
* of input data meets the
* [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
- * each operator by inserting [[Exchange]] Operators where required.
+ * each operator by inserting [[Exchange]] Operators where required. Also ensure that the
+ * required input partition ordering requirements are met.
*/
-private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
+private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
// TODO: Determine the number of partitions.
def numPartitions: Int = sqlContext.conf.numShufflePartitions
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator: SparkPlan =>
- // Check if every child's outputPartitioning satisfies the corresponding
+ // True iff every child's outputPartitioning satisfies the corresponding
// required data distribution.
def meetsRequirements: Boolean =
- !operator.requiredChildDistribution.zip(operator.children).map {
+ operator.requiredChildDistribution.zip(operator.children).forall {
case (required, child) =>
val valid = child.outputPartitioning.satisfies(required)
logDebug(
s"${if (valid) "Valid" else "Invalid"} distribution," +
s"required: $required current: ${child.outputPartitioning}")
valid
- }.exists(!_)
+ }
- // Check if outputPartitionings of children are compatible with each other.
+ // True iff any of the children are incorrectly sorted.
+ def needsAnySort: Boolean =
+ operator.requiredChildOrdering.zip(operator.children).exists {
+ case (required, child) => required.nonEmpty && required != child.outputOrdering
+ }
+
+ // True iff outputPartitionings of children are compatible with each other.
// It is possible that every child satisfies its required data distribution
// but two children have incompatible outputPartitionings. For example,
// A dataset is range partitioned by "a.asc" (RangePartitioning) and another
@@ -157,28 +208,69 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
case Seq(a,b) => a compatibleWith b
}.exists(!_)
- // Check if the partitioning we want to ensure is the same as the child's output
- // partitioning. If so, we do not need to add the Exchange operator.
- def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan =
- if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child
+ // Adds Exchange or Sort operators as required
+ def addOperatorsIfNecessary(
+ partitioning: Partitioning,
+ rowOrdering: Seq[SortOrder],
+ child: SparkPlan): SparkPlan = {
+ val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
+ val needsShuffle = child.outputPartitioning != partitioning
+ val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering)
+
+ if (needSort && needsShuffle && canSortWithShuffle) {
+ Exchange(partitioning, rowOrdering, child)
+ } else {
+ val withShuffle = if (needsShuffle) {
+ Exchange(partitioning, Nil, child)
+ } else {
+ child
+ }
- if (meetsRequirements && compatible) {
+ val withSort = if (needSort) {
+ if (sqlContext.conf.externalSortEnabled) {
+ ExternalSort(rowOrdering, global = false, withShuffle)
+ } else {
+ Sort(rowOrdering, global = false, withShuffle)
+ }
+ } else {
+ withShuffle
+ }
+
+ withSort
+ }
+ }
+
+ if (meetsRequirements && compatible && !needsAnySort) {
operator
} else {
// At least one child does not satisfies its required data distribution or
// at least one child's outputPartitioning is not compatible with another child's
// outputPartitioning. In this case, we need to add Exchange operators.
- val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map {
- case (AllTuples, child) =>
- addExchangeIfNecessary(SinglePartition, child)
- case (ClusteredDistribution(clustering), child) =>
- addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
- case (OrderedDistribution(ordering), child) =>
- addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
- case (UnspecifiedDistribution, child) => child
- case (dist, _) => sys.error(s"Don't know how to ensure $dist")
+ val requirements =
+ (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
+
+ val fixedChildren = requirements.zipped.map {
+ case (AllTuples, rowOrdering, child) =>
+ addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
+ case (ClusteredDistribution(clustering), rowOrdering, child) =>
+ addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
+ case (OrderedDistribution(ordering), rowOrdering, child) =>
+ addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
+
+ case (UnspecifiedDistribution, Seq(), child) =>
+ child
+ case (UnspecifiedDistribution, rowOrdering, child) =>
+ if (sqlContext.conf.externalSortEnabled) {
+ ExternalSort(rowOrdering, global = false, child)
+ } else {
+ Sort(rowOrdering, global = false, child)
+ }
+
+ case (dist, ordering, _) =>
+ sys.error(s"Don't know how to ensure $dist with ordering $ordering")
}
- operator.withNewChildren(repartitionedChildren)
+
+ operator.withNewChildren(fixedChildren)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index fabcf6b4a0..e159ffe66c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -72,6 +72,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
def requiredChildDistribution: Seq[Distribution] =
Seq.fill(children.size)(UnspecifiedDistribution)
+ /** Specifies how data is ordered in each partition. */
+ def outputOrdering: Seq[SortOrder] = Nil
+
+ /** Specifies sort order for each partition requirements on the input data for this operator. */
+ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
+
/**
* Runs this query returning the result as an RDD.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5b99e40c2f..e687d01f57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -90,6 +90,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
+ // If the sort merge join option is set, we want to use sort merge join prior to hashjoin
+ // for now let's support inner join first, then add outer join
+ case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
+ if sqlContext.conf.sortMergeJoinEnabled =>
+ val mergeJoin =
+ joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
+ condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
val buildSide =
if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
@@ -309,7 +317,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.OneRowRelation =>
execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
case logical.Repartition(expressions, child) =>
- execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
+ execution.Exchange(
+ HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil
case e @ EvaluatePython(udf, child, _) =>
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index f8221f41bc..308dae236a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -41,6 +41,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
val resuableProjection = buildProjection()
iter.map(resuableProjection)
}
+
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
}
/**
@@ -55,6 +57,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
override def execute(): RDD[Row] = child.execute().mapPartitions { iter =>
iter.filter(conditionEvaluator)
}
+
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
}
/**
@@ -147,6 +151,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1)
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
}
/**
@@ -172,6 +178,8 @@ case class Sort(
}
override def output: Seq[Attribute] = child.output
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
}
/**
@@ -202,6 +210,8 @@ case class ExternalSort(
}
override def output: Seq[Attribute] = child.output
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
new file mode 100644
index 0000000000..b5123668ba
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import java.util.NoSuchElementException
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.collection.CompactBuffer
+
+/**
+ * :: DeveloperApi ::
+ * Performs an sort merge join of two child relations.
+ */
+@DeveloperApi
+case class SortMergeJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode {
+
+ override def output: Seq[Attribute] = left.output ++ right.output
+
+ override def outputPartitioning: Partitioning = left.outputPartitioning
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ // this is to manually construct an ordering that can be used to compare keys from both sides
+ private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType))
+
+ override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
+
+ @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
+ @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
+
+ private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] =
+ keys.map(SortOrder(_, Ascending))
+
+ override def execute(): RDD[Row] = {
+ val leftResults = left.execute().map(_.copy())
+ val rightResults = right.execute().map(_.copy())
+
+ leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
+ new Iterator[Row] {
+ // Mutable per row objects.
+ private[this] val joinRow = new JoinedRow5
+ private[this] var leftElement: Row = _
+ private[this] var rightElement: Row = _
+ private[this] var leftKey: Row = _
+ private[this] var rightKey: Row = _
+ private[this] var rightMatches: CompactBuffer[Row] = _
+ private[this] var rightPosition: Int = -1
+ private[this] var stop: Boolean = false
+ private[this] var matchKey: Row = _
+
+ // initialize iterator
+ initialize()
+
+ override final def hasNext: Boolean = nextMatchingPair()
+
+ override final def next(): Row = {
+ if (hasNext) {
+ // we are using the buffered right rows and run down left iterator
+ val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
+ rightPosition += 1
+ if (rightPosition >= rightMatches.size) {
+ rightPosition = 0
+ fetchLeft()
+ if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) {
+ stop = false
+ rightMatches = null
+ }
+ }
+ joinedRow
+ } else {
+ // no more result
+ throw new NoSuchElementException
+ }
+ }
+
+ private def fetchLeft() = {
+ if (leftIter.hasNext) {
+ leftElement = leftIter.next()
+ leftKey = leftKeyGenerator(leftElement)
+ } else {
+ leftElement = null
+ }
+ }
+
+ private def fetchRight() = {
+ if (rightIter.hasNext) {
+ rightElement = rightIter.next()
+ rightKey = rightKeyGenerator(rightElement)
+ } else {
+ rightElement = null
+ }
+ }
+
+ private def initialize() = {
+ fetchLeft()
+ fetchRight()
+ }
+
+ /**
+ * Searches the right iterator for the next rows that have matches in left side, and store
+ * them in a buffer.
+ *
+ * @return true if the search is successful, and false if the right iterator runs out of
+ * tuples.
+ */
+ private def nextMatchingPair(): Boolean = {
+ if (!stop && rightElement != null) {
+ // run both side to get the first match pair
+ while (!stop && leftElement != null && rightElement != null) {
+ val comparing = keyOrdering.compare(leftKey, rightKey)
+ // for inner join, we need to filter those null keys
+ stop = comparing == 0 && !leftKey.anyNull
+ if (comparing > 0 || rightKey.anyNull) {
+ fetchRight()
+ } else if (comparing < 0 || leftKey.anyNull) {
+ fetchLeft()
+ }
+ }
+ rightMatches = new CompactBuffer[Row]()
+ if (stop) {
+ stop = false
+ // iterate the right side to buffer all rows that matches
+ // as the records should be ordered, exit when we meet the first that not match
+ while (!stop && rightElement != null) {
+ rightMatches += rightElement
+ fetchRight()
+ stop = keyOrdering.compare(leftKey, rightKey) != 0
+ }
+ if (rightMatches.size > 0) {
+ rightPosition = 0
+ matchKey = leftKey
+ }
+ }
+ }
+ rightMatches != null && rightMatches.size > 0
+ }
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index e4dee87849..037d392c1f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -51,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
case j: BroadcastLeftSemiJoinHash => j
+ case j: SortMergeJoin => j
}
assert(operators.size === 1)
@@ -62,6 +63,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("join operator selection") {
cacheManager.clearCache()
+ val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
@@ -91,17 +93,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
classOf[BroadcastNestedLoopJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ try {
+ conf.setConf("spark.sql.planner.sortMergeJoin", "true")
+ Seq(
+ ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin])
+ ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ } finally {
+ conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString)
+ }
}
test("broadcasted hash join operator selection") {
cacheManager.clearCache()
sql("CACHE TABLE testData")
+ val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled
Seq(
("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]),
- ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin])
+ ("SELECT * FROM testData join testData2 ON key = a where key = 2",
+ classOf[BroadcastHashJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ try {
+ conf.setConf("spark.sql.planner.sortMergeJoin", "true")
+ Seq(
+ ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
+ ("SELECT * FROM testData join testData2 ON key = a and key = 2",
+ classOf[BroadcastHashJoin]),
+ ("SELECT * FROM testData join testData2 ON key = a where key = 2",
+ classOf[BroadcastHashJoin])
+ ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ } finally {
+ conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString)
+ }
sql("UNCACHE TABLE testData")
}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
new file mode 100644
index 0000000000..65d070bd3c
--- /dev/null
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.hive.test.TestHive
+
+/**
+ * Runs the test cases that are included in the hive distribution with sort merge join is true.
+ */
+class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
+ override def beforeAll() {
+ super.beforeAll()
+ TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true")
+ }
+
+ override def afterAll() {
+ TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false")
+ super.afterAll()
+ }
+
+ override def whiteList = Seq(
+ "auto_join0",
+ "auto_join1",
+ "auto_join10",
+ "auto_join11",
+ "auto_join12",
+ "auto_join13",
+ "auto_join14",
+ "auto_join14_hadoop20",
+ "auto_join15",
+ "auto_join17",
+ "auto_join18",
+ "auto_join19",
+ "auto_join2",
+ "auto_join20",
+ "auto_join21",
+ "auto_join22",
+ "auto_join23",
+ "auto_join24",
+ "auto_join25",
+ "auto_join26",
+ "auto_join27",
+ "auto_join28",
+ "auto_join3",
+ "auto_join30",
+ "auto_join31",
+ "auto_join32",
+ "auto_join4",
+ "auto_join5",
+ "auto_join6",
+ "auto_join7",
+ "auto_join8",
+ "auto_join9",
+ "auto_join_filters",
+ "auto_join_nulls",
+ "auto_join_reordering_values",
+ "auto_smb_mapjoin_14",
+ "auto_sortmerge_join_1",
+ "auto_sortmerge_join_10",
+ "auto_sortmerge_join_11",
+ "auto_sortmerge_join_12",
+ "auto_sortmerge_join_13",
+ "auto_sortmerge_join_14",
+ "auto_sortmerge_join_15",
+ "auto_sortmerge_join_16",
+ "auto_sortmerge_join_2",
+ "auto_sortmerge_join_3",
+ "auto_sortmerge_join_4",
+ "auto_sortmerge_join_5",
+ "auto_sortmerge_join_6",
+ "auto_sortmerge_join_7",
+ "auto_sortmerge_join_8",
+ "auto_sortmerge_join_9",
+ "correlationoptimizer1",
+ "correlationoptimizer10",
+ "correlationoptimizer11",
+ "correlationoptimizer13",
+ "correlationoptimizer14",
+ "correlationoptimizer15",
+ "correlationoptimizer2",
+ "correlationoptimizer3",
+ "correlationoptimizer4",
+ "correlationoptimizer6",
+ "correlationoptimizer7",
+ "correlationoptimizer8",
+ "correlationoptimizer9",
+ "join0",
+ "join1",
+ "join10",
+ "join11",
+ "join12",
+ "join13",
+ "join14",
+ "join14_hadoop20",
+ "join15",
+ "join16",
+ "join17",
+ "join18",
+ "join19",
+ "join2",
+ "join20",
+ "join21",
+ "join22",
+ "join23",
+ "join24",
+ "join25",
+ "join26",
+ "join27",
+ "join28",
+ "join29",
+ "join3",
+ "join30",
+ "join31",
+ "join32",
+ "join32_lessSize",
+ "join33",
+ "join34",
+ "join35",
+ "join36",
+ "join37",
+ "join38",
+ "join39",
+ "join4",
+ "join40",
+ "join41",
+ "join5",
+ "join6",
+ "join7",
+ "join8",
+ "join9",
+ "join_1to1",
+ "join_array",
+ "join_casesensitive",
+ "join_empty",
+ "join_filters",
+ "join_hive_626",
+ "join_map_ppr",
+ "join_nulls",
+ "join_nullsafe",
+ "join_rc",
+ "join_reorder2",
+ "join_reorder3",
+ "join_reorder4",
+ "join_star"
+ )
+}