author    Yin Huai <yhuai@databricks.com>    2015-08-02 20:44:23 -0700
committer Yin Huai <yhuai@databricks.com>    2015-08-02 20:44:23 -0700
commit    114ff926fcd078697c1111279b5cf6173b515865 (patch)
tree      94d08d409de119301be8596ad9cc9b7213e801f7 /sql
parent    30e89111d673776a6b59b11cdb29ab8713ba6f7c (diff)
[SPARK-2205] [SQL] Avoid unnecessary exchange operators in multi-way joins
This PR adds `PartitioningCollection`, which is used to represent the `outputPartitioning` of SparkPlans with multiple children (e.g. `ShuffledHashJoin`). With it, a `SparkPlan` can carry multiple descriptions of its partitioning scheme. Taking `ShuffledHashJoin` as an example, its output has two such descriptions: `left.outputPartitioning` and `right.outputPartitioning`. As a result, a query like `select * from t1 join t2 on (t1.x = t2.x) join t3 on (t2.x = t3.x)` needs only three Exchange operators (when shuffled joins are required) instead of four.

The code in this PR was authored by yhuai; I'm opening this PR to factor out this change from #7685, a larger pull request which contains two other optimizations.

Author: Yin Huai <yhuai@databricks.com>
Author: Josh Rosen <joshrosen@databricks.com>

Closes #7773 from JoshRosen/multi-way-join-planning-improvements and squashes the following commits:

5c45924 [Josh Rosen] Merge remote-tracking branch 'origin/master' into multi-way-join-planning-improvements
cd8269b [Josh Rosen] Refactor test to use SQLTestUtils
2963857 [Yin Huai] Revert unnecessary SqlConf change.
73913f7 [Yin Huai] Add comments and test. Also, revert the change in ShuffledHashOuterJoin for now.
4a99204 [Josh Rosen] Delete unrelated expression change
884ab95 [Josh Rosen] Carve out only SPARK-2205 changes.
247e5fa [Josh Rosen] Merge remote-tracking branch 'origin/master' into multi-way-join-planning-improvements
c57a954 [Yin Huai] Bug fix.
d3d2e64 [Yin Huai] First round of cleanup.
f9516b0 [Yin Huai] Style
c6667e7 [Yin Huai] Add PartitioningCollection.
e616d3b [Yin Huai] wip
7c2d2d8 [Yin Huai] Bug fix and refactoring.
69bb072 [Yin Huai] Introduce NullSafeHashPartitioning and NullUnsafePartitioning.
d5b84c3 [Yin Huai] Do not add unnecessary filters.
2201129 [Yin Huai] Filter out rows that will not be joined in equal joins early.
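As a standalone sketch of the core idea (simplified stand-in types with plain string keys, not the real catalyst `Expression`-based classes shown in the diff below): a join that advertises both children's partitionings lets the planner see that a downstream join's requirement is already met.

```scala
// Simplified stand-ins for the catalyst classes changed in this patch; keys
// are plain strings here instead of catalyst Expressions.
sealed trait Partitioning {
  def guarantees(other: Partitioning): Boolean
}

case class HashPartitioning(keys: Set[String], numPartitions: Int) extends Partitioning {
  // Same keys hashed into the same number of partitions => same row placement.
  override def guarantees(other: Partitioning): Boolean = other match {
    case HashPartitioning(k, n) => k == keys && n == numPartitions
    case _ => false
  }
}

case class PartitioningCollection(partitionings: Seq[Partitioning]) extends Partitioning {
  // The collection guarantees `other` as soon as any member does.
  override def guarantees(other: Partitioning): Boolean =
    partitionings.exists(_.guarantees(other))
}

object MultiWayJoinSketch extends App {
  // t1 JOIN t2 ON t1.x = t2.x: both inputs are shuffled (two Exchanges), and
  // the join's output is partitioned by t1.x and, equivalently, by t2.x.
  val firstJoinOutput = PartitioningCollection(Seq(
    HashPartitioning(Set("t1.x"), 200),
    HashPartitioning(Set("t2.x"), 200)))

  // ... JOIN t3 ON t2.x = t3.x: t3 needs its own shuffle (third Exchange), but
  // the first join's output already guarantees the required partitioning, so
  // a fourth Exchange is unnecessary.
  assert(firstJoinOutput.guarantees(HashPartitioning(Set("t2.x"), 200)))
}
```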
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala | 87
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala | 2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala | 2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala | 4
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala | 9
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala | 7
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala | 10
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 3
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 49
10 files changed, 148 insertions, 31 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index f4d1dbaf28..ec659ce789 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -60,8 +60,9 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi
/**
* Represents data where tuples have been ordered according to the `ordering`
* [[Expression Expressions]]. This is a strictly stronger guarantee than
- * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for
- * the ordering expressions are contiguous and will never be split across partitions.
+ * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the
+ * same value for the ordering expressions are contiguous and will never be split across
+ * partitions.
*/
case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
require(
@@ -86,8 +87,12 @@ sealed trait Partitioning {
*/
def satisfies(required: Distribution): Boolean
- /** Returns the expressions that are used to key the partitioning. */
- def keyExpressions: Seq[Expression]
+ /**
+ * Returns true iff we can say that the partitioning scheme of this [[Partitioning]]
+ * guarantees the same partitioning scheme described by `other`.
+ */
+ // TODO: Add an example once we have the `nullSafe` concept.
+ def guarantees(other: Partitioning): Boolean
}
case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -96,7 +101,7 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
case _ => false
}
- override def keyExpressions: Seq[Expression] = Nil
+ override def guarantees(other: Partitioning): Boolean = false
}
case object SinglePartition extends Partitioning {
@@ -104,7 +109,10 @@ case object SinglePartition extends Partitioning {
override def satisfies(required: Distribution): Boolean = true
- override def keyExpressions: Seq[Expression] = Nil
+ override def guarantees(other: Partitioning): Boolean = other match {
+ case SinglePartition => true
+ case _ => false
+ }
}
case object BroadcastPartitioning extends Partitioning {
@@ -112,7 +120,10 @@ case object BroadcastPartitioning extends Partitioning {
override def satisfies(required: Distribution): Boolean = true
- override def keyExpressions: Seq[Expression] = Nil
+ override def guarantees(other: Partitioning): Boolean = other match {
+ case BroadcastPartitioning => true
+ case _ => false
+ }
}
/**
@@ -127,7 +138,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
override def nullable: Boolean = false
override def dataType: DataType = IntegerType
- private[this] lazy val clusteringSet = expressions.toSet
+ lazy val clusteringSet = expressions.toSet
override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
@@ -136,7 +147,11 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}
- override def keyExpressions: Seq[Expression] = expressions
+ override def guarantees(other: Partitioning): Boolean = other match {
+ case o: HashPartitioning =>
+ this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions
+ case _ => false
+ }
}
/**
@@ -170,5 +185,57 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}
- override def keyExpressions: Seq[Expression] = ordering.map(_.child)
+ override def guarantees(other: Partitioning): Boolean = other match {
+ case o: RangePartitioning => this == o
+ case _ => false
+ }
+}
+
+/**
+ * A collection of [[Partitioning]]s that can be used to describe the partitioning
+ * scheme of the output of a physical operator. It is usually used for an operator
+ * that has multiple children. In this case, a [[Partitioning]] in this collection
+ * describes how this operator's output is partitioned based on expressions from
+ * a child. For example, for a Join operator on two tables `A` and `B`
+ * a child. For example, for a Join operator on two tables `A` and `B`
+ * with a join condition `A.key1 = B.key2`, assuming we use the HashPartitioning scheme,
+ * there are two [[Partitioning]]s that can be used to describe how the output of
+ * this Join operator is partitioned, which are `HashPartitioning(A.key1)` and
+ * `HashPartitioning(B.key2)`. It is also worth noting that `partitionings`
+ * in this collection do not need to be equivalent, which is useful for
+ * Outer Join operators.
+ */
+case class PartitioningCollection(partitionings: Seq[Partitioning])
+ extends Expression with Partitioning with Unevaluable {
+
+ require(
+ partitionings.map(_.numPartitions).distinct.length == 1,
+ s"PartitioningCollection requires all of its partitionings have the same numPartitions.")
+
+ override def children: Seq[Expression] = partitionings.collect {
+ case expr: Expression => expr
+ }
+
+ override def nullable: Boolean = false
+
+ override def dataType: DataType = IntegerType
+
+ override val numPartitions = partitionings.map(_.numPartitions).distinct.head
+
+ /**
+ * Returns true if any `partitioning` of this collection satisfies the given
+ * [[Distribution]].
+ */
+ override def satisfies(required: Distribution): Boolean =
+ partitionings.exists(_.satisfies(required))
+
+ /**
+ * Returns true if any `partitioning` of this collection guarantees
+ * the given [[Partitioning]].
+ */
+ override def guarantees(other: Partitioning): Boolean =
+ partitionings.exists(_.guarantees(other))
+
+ override def toString: String = {
+ partitionings.map(_.toString).mkString("(", " or ", ")")
+ }
}
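As a hedged aside on the `satisfies` side of the new class (again with simplified stand-in types, not the real `Distribution`/`Expression` machinery): a `PartitioningCollection` satisfies a required distribution as soon as any member does, mirroring the `exists` calls above.

```scala
// Stand-in types for illustration only; the real classes live in
// org.apache.spark.sql.catalyst.plans.physical.
sealed trait Distribution
case class ClusteredDistribution(clustering: Set[String]) extends Distribution

sealed trait Partitioning {
  def satisfies(required: Distribution): Boolean
}

case class HashPartitioning(expressions: Set[String], numPartitions: Int) extends Partitioning {
  // Hash-partitioning on a subset of the required clustering keys keeps all
  // rows that agree on the clustering keys in a single partition.
  override def satisfies(required: Distribution): Boolean = required match {
    case ClusteredDistribution(clustering) => expressions.subsetOf(clustering)
  }
}

case class PartitioningCollection(partitionings: Seq[Partitioning]) extends Partitioning {
  override def satisfies(required: Distribution): Boolean =
    partitionings.exists(_.satisfies(required))
}

object SatisfiesSketch extends App {
  val joinOutput = PartitioningCollection(Seq(
    HashPartitioning(Set("A.key1"), 10),
    HashPartitioning(Set("B.key2"), 10)))

  // An operator clustering on B.key2 is satisfied by the second member...
  assert(joinOutput.satisfies(ClusteredDistribution(Set("B.key2"))))
  // ...while clustering on an unrelated key is not.
  assert(!joinOutput.satisfies(ClusteredDistribution(Set("C.key3"))))
}
```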
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
index c046dbf4dc..827f7ce692 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -42,7 +42,7 @@ class DistributionSuite extends SparkFunSuite {
}
}
- test("HashPartitioning is the output partitioning") {
+ test("HashPartitioning (with nullSafe = true) is the output partitioning") {
// Cases which do not need an exchange between two data properties.
checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 6bd57f010a..05b009d193 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -209,7 +209,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
child: SparkPlan): SparkPlan = {
def addShuffleIfNecessary(child: SparkPlan): SparkPlan = {
- if (child.outputPartitioning != partitioning) {
+ if (!child.outputPartitioning.guarantees(partitioning)) {
Exchange(partitioning, child)
} else {
child
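Why `guarantees` rather than the old `!=` check: a `PartitioningCollection` never compares equal to the single `HashPartitioning` a parent requests, even when one of its members matches, so plain equality would keep inserting redundant shuffles. A minimal illustration, redeclaring the stand-in types so the snippet runs on its own:

```scala
sealed trait Partitioning { def guarantees(other: Partitioning): Boolean }

case class HashPartitioning(keys: Set[String], numPartitions: Int) extends Partitioning {
  def guarantees(other: Partitioning): Boolean = other match {
    case HashPartitioning(k, n) => k == keys && n == numPartitions
    case _ => false
  }
}

case class PartitioningCollection(members: Seq[Partitioning]) extends Partitioning {
  def guarantees(other: Partitioning): Boolean = members.exists(_.guarantees(other))
}

object ShuffleCheckSketch extends App {
  val target = HashPartitioning(Set("t2.x"), 200)
  val childOutput = PartitioningCollection(Seq(
    HashPartitioning(Set("t1.x"), 200),
    HashPartitioning(Set("t2.x"), 200)))

  assert(childOutput != target)          // old check: would add a redundant Exchange
  assert(childOutput.guarantees(target)) // new check: no Exchange needed
}
```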
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index 77e7fe7100..309716a0ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.ThreadUtils
@@ -57,6 +57,8 @@ case class BroadcastHashOuterJoin(
override def requiredChildDistribution: Seq[Distribution] =
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
+ override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
+
@transient
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 7e671e7914..a323aea4ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -22,7 +22,6 @@ import java.util.{HashMap => JavaHashMap}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.util.collection.CompactBuffer
@@ -38,14 +37,6 @@ trait HashOuterJoin {
val left: SparkPlan
val right: SparkPlan
- override def outputPartitioning: Partitioning = joinType match {
- case LeftOuter => left.outputPartitioning
- case RightOuter => right.outputPartitioning
- case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
- case x =>
- throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
- }
-
override def output: Seq[Attribute] = {
joinType match {
case LeftOuter =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 26a664104d..68ccd34d8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
/**
@@ -37,7 +37,9 @@ case class LeftSemiJoinHash(
right: SparkPlan,
condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
- override def requiredChildDistribution: Seq[ClusteredDistribution] =
+ override def outputPartitioning: Partitioning = left.outputPartitioning
+
+ override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
protected override def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 5439e10a60..fc6efe87bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
/**
@@ -38,9 +38,10 @@ case class ShuffledHashJoin(
right: SparkPlan)
extends BinaryNode with HashJoin {
- override def outputPartitioning: Partitioning = left.outputPartitioning
+ override def outputPartitioning: Partitioning =
+ PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
- override def requiredChildDistribution: Seq[ClusteredDistribution] =
+ override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
protected override def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
index d29b593207..eee8ad800f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution}
+import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
@@ -44,6 +44,14 @@ case class ShuffledHashOuterJoin(
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+ override def outputPartitioning: Partitioning = joinType match {
+ case LeftOuter => left.outputPartitioning
+ case RightOuter => right.outputPartitioning
+ case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+ case x =>
+ throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val joinedRow = new JoinedRow()
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
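A note on why the outer joins above do not advertise a `PartitioningCollection`: after a LEFT OUTER join the right side's columns may be null-extended, so only the preserved (left) side's partitioning still describes the output, and a FULL OUTER join preserves neither side's. A hedged sketch of that dispatch with stand-in types:

```scala
sealed trait JoinType
case object LeftOuter  extends JoinType
case object RightOuter extends JoinType
case object FullOuter  extends JoinType

sealed trait Partitioning
case class HashPartitioning(keys: Seq[String], numPartitions: Int) extends Partitioning
case class UnknownPartitioning(numPartitions: Int) extends Partitioning

// Only the preserved side's partitioning survives an outer join: the other
// side's key columns can be null in the output, so that side's hash placement
// no longer holds for every row.
def outerJoinOutputPartitioning(
    joinType: JoinType,
    left: Partitioning,
    right: Partitioning,
    numPartitions: Int): Partitioning = joinType match {
  case LeftOuter  => left
  case RightOuter => right
  case FullOuter  => UnknownPartitioning(numPartitions)
}

// Example: a full outer join discards both children's guarantees.
assert(outerJoinOutputPartitioning(FullOuter,
  HashPartitioning(Seq("k"), 8), HashPartitioning(Seq("k"), 8), 8) == UnknownPartitioning(8))
```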
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index bb18b5403f..41be78afd3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -40,7 +40,8 @@ case class SortMergeJoin(
override def output: Seq[Attribute] = left.output ++ right.output
- override def outputPartitioning: Partitioning = left.outputPartitioning
+ override def outputPartitioning: Partitioning =
+ PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 845ce669f0..18b0e54dc7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -23,14 +23,18 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test.TestSQLContext.planner._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{Row, SQLConf, execution}
+import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution}
-class PlannerSuite extends SparkFunSuite {
+class PlannerSuite extends SparkFunSuite with SQLTestUtils {
+
+ override def sqlContext: SQLContext = TestSQLContext
+
private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
val planned =
@@ -157,4 +161,45 @@ class PlannerSuite extends SparkFunSuite {
val planned = planner.TakeOrderedAndProject(query)
assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
}
+
+ test("PartitioningCollection") {
+ withTempTable("normal", "small", "tiny") {
+ testData.registerTempTable("normal")
+ testData.limit(10).registerTempTable("small")
+ testData.limit(3).registerTempTable("tiny")
+
+ // Disable broadcast join
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ {
+ val numExchanges = sql(
+ """
+ |SELECT *
+ |FROM
+ | normal JOIN small ON (normal.key = small.key)
+ | JOIN tiny ON (small.key = tiny.key)
+ """.stripMargin
+ ).queryExecution.executedPlan.collect {
+ case exchange: Exchange => exchange
+ }.length
+ assert(numExchanges === 3)
+ }
+
+ {
+ // This second query joins on different keys:
+ val numExchanges = sql(
+ """
+ |SELECT *
+ |FROM
+ | normal JOIN small ON (normal.key = small.key)
+ | JOIN tiny ON (normal.key = tiny.key)
+ """.stripMargin
+ ).queryExecution.executedPlan.collect {
+ case exchange: Exchange => exchange
+ }.length
+ assert(numExchanges === 3)
+ }
+
+ }
+ }
+ }
}