Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala | 128
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala | 104
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala | 5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 151
5 files changed, 328 insertions, 65 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index ec659ce789..5a89a90b73 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -75,6 +75,37 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
def clustering: Set[Expression] = ordering.map(_.child).toSet
}
+/**
+ * Describes how an operator's output is split across partitions. The `compatibleWith`,
+ * `guarantees`, and `satisfies` methods describe relationships between child partitionings,
+ * target partitionings, and [[Distribution]]s. These relations are described more precisely in
+ * their individual method docs, but at a high level:
+ *
+ * - `satisfies` is a relationship between partitionings and distributions.
+ * - `compatibleWith` is a relationship between an operator's child output partitionings.
+ * - `guarantees` is a relationship between a child's existing output partitioning and a target
+ * output partitioning.
+ *
+ * Diagrammatically:
+ *
+ *     +--------------+
+ *     | Distribution |
+ *     +--------------+
+ *             ^
+ *             |
+ *         satisfies
+ *             |
+ *     +--------------+                  +--------------+
+ *     |    Child     |                  |    Target    |
+ *  +--| Partitioning |----guarantees--->| Partitioning |
+ *  |  +--------------+                  +--------------+
+ *  |          ^
+ *  |          |
+ *  |    compatibleWith
+ *  |          |
+ *  +----------+
+ *
+ */
sealed trait Partitioning {
/** Returns the number of partitions that the data is split across */
val numPartitions: Int
@@ -90,9 +121,66 @@ sealed trait Partitioning {
/**
* Returns true iff we can say that the partitioning scheme of this [[Partitioning]]
* guarantees the same partitioning scheme described by `other`.
+ *
+ * Compatibility of partitionings is only checked for operators that have multiple children
+ * and that require a specific child output [[Distribution]], such as joins.
+ *
+ * Intuitively, partitionings are compatible if they route the same partitioning key to the same
+ * partition. For instance, two hash partitionings are only compatible if they produce the same
+ * number of output partitions and hash records according to the same hash function and
+ * same partitioning key schema.
+ *
+ * Put another way, two partitionings are compatible with each other if they satisfy all of the
+ * same distribution guarantees.
*/
- // TODO: Add an example once we have the `nullSafe` concept.
- def guarantees(other: Partitioning): Boolean
+ def compatibleWith(other: Partitioning): Boolean
+
+ /**
+ * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] guarantees
+ * the same partitioning scheme described by `other`. If `A.guarantees(B)`, then repartitioning
+ * the child's output according to `B` will be unnecessary. `guarantees` is used as a performance
+ * optimization to allow the exchange planner to avoid redundant repartitionings. By default,
+ * a partitioning only guarantees partitionings that are equal to itself (i.e. the same number
+ * of partitions, same strategy (range or hash), etc).
+ *
+ * In order to enable more aggressive optimization, this strict equality check can be relaxed.
+ * For example, say that the planner needs to repartition all of an operator's children so that
+ * they satisfy the [[AllTuples]] distribution. One way to do this is to repartition all children
+ * to have the [[SinglePartition]] partitioning. If one of the operator's children already happens
+ * to be hash-partitioned with a single partition then we do not need to re-shuffle this child;
+ * this repartitioning can be avoided if a single-partition [[HashPartitioning]] `guarantees`
+ * [[SinglePartition]].
+ *
+ * The SinglePartition example given above is not particularly interesting; guarantees' real
+ * value occurs for more advanced partitioning strategies. SPARK-7871 will introduce a notion
+ * of null-safe partitionings, under which partitionings can specify whether rows whose
+ * partitioning keys contain null values will be grouped into the same partition or whether they
+ * will have an unknown / random distribution. If a partitioning does not require nulls to be
+ * clustered then a partitioning which _does_ cluster nulls will guarantee the null clustered
+ * partitioning. The converse is not true, however: a partitioning which clusters nulls cannot
+ * be guaranteed by one which does not cluster them. Thus, in general `guarantees` is not a
+ * symmetric relation.
+ *
+ * Another way to think about `guarantees`: if `A.guarantees(B)`, then any partitioning of rows
+ * produced by `A` could have also been produced by `B`.
+ */
+ def guarantees(other: Partitioning): Boolean = this == other
+}
+
+object Partitioning {
+ def allCompatible(partitionings: Seq[Partitioning]): Boolean = {
+ // Note: this assumes transitivity
+ partitionings.sliding(2).map {
+ case Seq(a) => true
+ case Seq(a, b) =>
+ if (a.numPartitions != b.numPartitions) {
+ assert(!a.compatibleWith(b) && !b.compatibleWith(a))
+ false
+ } else {
+ a.compatibleWith(b) && b.compatibleWith(a)
+ }
+ }.forall(_ == true)
+ }
}
case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -101,6 +189,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
case _ => false
}
+ override def compatibleWith(other: Partitioning): Boolean = false
+
override def guarantees(other: Partitioning): Boolean = false
}
@@ -109,21 +199,9 @@ case object SinglePartition extends Partitioning {
override def satisfies(required: Distribution): Boolean = true
- override def guarantees(other: Partitioning): Boolean = other match {
- case SinglePartition => true
- case _ => false
- }
-}
-
-case object BroadcastPartitioning extends Partitioning {
- val numPartitions = 1
+ override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1
- override def satisfies(required: Distribution): Boolean = true
-
- override def guarantees(other: Partitioning): Boolean = other match {
- case BroadcastPartitioning => true
- case _ => false
- }
+ override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1
}
/**
@@ -147,6 +225,12 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}
+ override def compatibleWith(other: Partitioning): Boolean = other match {
+ case o: HashPartitioning =>
+ this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions
+ case _ => false
+ }
+
override def guarantees(other: Partitioning): Boolean = other match {
case o: HashPartitioning =>
this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions
@@ -185,6 +269,11 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}
+ override def compatibleWith(other: Partitioning): Boolean = other match {
+ case o: RangePartitioning => this == o
+ case _ => false
+ }
+
override def guarantees(other: Partitioning): Boolean = other match {
case o: RangePartitioning => this == o
case _ => false
@@ -229,6 +318,13 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
partitionings.exists(_.satisfies(required))
/**
+ * Returns true if any `partitioning` of this collection is compatible with
+ * the given [[Partitioning]].
+ */
+ override def compatibleWith(other: Partitioning): Boolean =
+ partitionings.exists(_.compatibleWith(other))
+
+ /**
* Returns true if any `partitioning` of this collection guarantees
* the given [[Partitioning]].
*/
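
To make the three relations concrete, here is a small standalone sketch (not part of the patch; it only uses the catalyst classes touched above and mirrors the assertions added to PlannerSuite below). Two hash partitionings over different keys each satisfy a clustered distribution over both keys, yet are not compatible with each other, and `guarantees` defaults to strict equality:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.physical._

object PartitioningRelationsSketch {
  def main(args: Array[String]): Unit = {
    val byA = HashPartitioning(Literal(1) :: Nil, 8)   // hash-partitioned on key "a"
    val byB = HashPartitioning(Literal(2) :: Nil, 8)   // hash-partitioned on key "b"
    val joinNeeds = ClusteredDistribution(Literal(1) :: Literal(2) :: Nil)

    // satisfies: each side individually satisfies clustering on (a, b)...
    assert(byA.satisfies(joinNeeds) && byB.satisfies(joinNeeds))
    // compatibleWith: ...but they do not route equal keys to the same partition,
    // so a join over both keys must still shuffle both sides.
    assert(!byA.compatibleWith(byB))
    assert(!Partitioning.allCompatible(Seq(byA, byB)))
    // guarantees: by default only an equal partitioning is guaranteed, and the
    // relation is not symmetric: SinglePartition guarantees any one-partition
    // target, while a one-partition HashPartitioning does not (yet) guarantee
    // SinglePartition, as the doc comment above notes.
    assert(byA.guarantees(HashPartitioning(Literal(1) :: Nil, 8)))
    assert(!byA.guarantees(byB))
    assert(SinglePartition.guarantees(HashPartitioning(Literal(1) :: Nil, 1)))
    assert(!HashPartitioning(Literal(1) :: Nil, 1).guarantees(SinglePartition))
  }
}
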
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 49bb729800..b89e634761 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -190,66 +190,72 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
* of input data meets the
* [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
* each operator by inserting [[Exchange]] Operators where required. Also ensure that the
- * required input partition ordering requirements are met.
+ * input partition ordering requirements are met.
*/
private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
// TODO: Determine the number of partitions.
- def numPartitions: Int = sqlContext.conf.numShufflePartitions
+ private def numPartitions: Int = sqlContext.conf.numShufflePartitions
- def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
- case operator: SparkPlan =>
- // Adds Exchange or Sort operators as required
- def addOperatorsIfNecessary(
- partitioning: Partitioning,
- rowOrdering: Seq[SortOrder],
- child: SparkPlan): SparkPlan = {
-
- def addShuffleIfNecessary(child: SparkPlan): SparkPlan = {
- if (!child.outputPartitioning.guarantees(partitioning)) {
- Exchange(partitioning, child)
- } else {
- child
- }
- }
+ /**
+ * Given a required distribution, returns a partitioning that satisfies that distribution.
+ */
+ private def canonicalPartitioning(requiredDistribution: Distribution): Partitioning = {
+ requiredDistribution match {
+ case AllTuples => SinglePartition
+ case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
+ case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
+ case dist => sys.error(s"Do not know how to satisfy distribution $dist")
+ }
+ }
- def addSortIfNecessary(child: SparkPlan): SparkPlan = {
-
- if (rowOrdering.nonEmpty) {
- // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort.
- val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min
- if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
- sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
- } else {
- child
- }
- } else {
- child
- }
- }
+ private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
+ val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
+ val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
+ var children: Seq[SparkPlan] = operator.children
- addSortIfNecessary(addShuffleIfNecessary(child))
+ // Ensure that the operator's children satisfy their output distribution requirements:
+ children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
+ if (child.outputPartitioning.satisfies(distribution)) {
+ child
+ } else {
+ Exchange(canonicalPartitioning(distribution), child)
}
+ }
- val requirements =
- (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
-
- val fixedChildren = requirements.zipped.map {
- case (AllTuples, rowOrdering, child) =>
- addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
- case (ClusteredDistribution(clustering), rowOrdering, child) =>
- addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
- case (OrderedDistribution(ordering), rowOrdering, child) =>
- addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
-
- case (UnspecifiedDistribution, Seq(), child) =>
+ // If the operator has multiple children and specifies child output distributions (e.g. join),
+ // then the children's output partitionings must be compatible:
+ if (children.length > 1
+ && requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
+ && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
+ children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
+ val targetPartitioning = canonicalPartitioning(distribution)
+ if (child.outputPartitioning.guarantees(targetPartitioning)) {
child
- case (UnspecifiedDistribution, rowOrdering, child) =>
- sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
+ } else {
+ Exchange(targetPartitioning, child)
+ }
+ }
+ }
- case (dist, ordering, _) =>
- sys.error(s"Don't know how to ensure $dist with ordering $ordering")
+ // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
+ children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
+ if (requiredOrdering.nonEmpty) {
+ // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
+ val minSize = Seq(requiredOrdering.size, child.outputOrdering.size).min
+ if (minSize == 0 || requiredOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
+ sqlContext.planner.BasicOperators.getSortOperator(requiredOrdering, global = false, child)
+ } else {
+ child
+ }
+ } else {
+ child
}
+ }
- operator.withNewChildren(fixedChildren)
+ operator.withNewChildren(children)
+ }
+
+ def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+ case operator: SparkPlan => ensureDistributionAndOrdering(operator)
}
}
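
The restructured rule above makes two passes over an operator's children: first it shuffles any child whose partitioning does not satisfy its required distribution; then, for multi-child operators, it re-shuffles children whose partitionings are not mutually compatible, skipping only children that already guarantee the canonical target; a final pass adds any required sorts. The standalone sketch below (not part of the patch; `canonical` restates the private canonicalPartitioning helper, with 200 standing in for spark.sql.shuffle.partitions) illustrates the first pass's decision:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.physical._

object EnsureRequirementsSketch {
  // Maps a required distribution to the partitioning an Exchange would target.
  def canonical(required: Distribution, numPartitions: Int): Partitioning = required match {
    case AllTuples                         => SinglePartition
    case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
    case OrderedDistribution(ordering)     => RangePartitioning(ordering, numPartitions)
    case dist => sys.error(s"Do not know how to satisfy distribution $dist")
  }

  def main(args: Array[String]): Unit = {
    val keys = Literal(1) :: Nil
    val required = ClusteredDistribution(keys)

    // A child already hash-partitioned on the required keys satisfies the
    // distribution, so the first pass leaves it untouched:
    assert(HashPartitioning(keys, 200).satisfies(required))
    // A child with unknown partitioning does not, so it would be wrapped in
    // Exchange(canonical(required, 200), child):
    assert(!UnknownPartitioning(5).satisfies(required))
    canonical(required, 200) match {
      case HashPartitioning(exprs, n) => assert(exprs == keys && n == 200)
      case other => sys.error(s"unexpected target partitioning: $other")
    }
    // AllTuples (an operator that needs all rows in one partition) maps to SinglePartition:
    assert(canonical(AllTuples, 200) == SinglePartition)
  }
}
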
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index c5d1ed0937..24950f2606 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -256,6 +256,11 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan)
extends UnaryNode {
override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = {
+ if (numPartitions == 1) SinglePartition
+ else UnknownPartitioning(numPartitions)
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
child.execute().map(_.copy()).coalesce(numPartitions, shuffle)
}
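
A usage-level sketch of why this override matters (not part of the patch; it assumes a Spark 1.5-era SQLContext is in scope and that the plan for coalesce(1) contains the physical Repartition node shown above): the node now reports SinglePartition rather than an unknown partitioning, so a downstream operator requiring AllTuples over its output no longer forces EnsureRequirements to insert an extra Exchange.

import org.apache.spark.sql.SQLContext

object RepartitionPartitioningSketch {
  def inspect(sqlContext: SQLContext): Unit = {
    val df = sqlContext.range(0, 100).coalesce(1)
    // Print each physical node with the output partitioning it reports; the
    // Repartition node introduced by coalesce(1) now shows SinglePartition.
    df.queryExecution.executedPlan.foreach { node =>
      println(s"${node.nodeName}: ${node.outputPartitioning}")
    }
  }
}
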
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
index 29f3beb3cb..855555dd1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
/**
@@ -33,6 +34,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe")
override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override def outputsUnsafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = false
override def canProcessSafeRows: Boolean = true
@@ -51,6 +54,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
@DeveloperApi
case class ConvertToSafe(child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override def outputsUnsafeRows: Boolean = false
override def canProcessUnsafeRows: Boolean = true
override def canProcessSafeRows: Boolean = false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 18b0e54dc7..5582caa0d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,9 +18,13 @@
package org.apache.spark.sql.execution
import org.apache.spark.SparkFunSuite
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
@@ -202,4 +206,151 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
}
}
}
+
+ // --- Unit tests of EnsureRequirements ---------------------------------------------------------
+
+ // When it comes to testing whether EnsureRequirements properly ensures distribution requirements,
+ // there are two dimensions that need to be considered: are the child partitionings compatible and
+ // do they satisfy the distribution requirements? As a result, we need at least four test cases.
+
+ private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = {
+ if (outputPlan.children.length > 1
+ && outputPlan.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) {
+ val childPartitionings = outputPlan.children.map(_.outputPartitioning)
+ if (!Partitioning.allCompatible(childPartitionings)) {
+ fail(s"Partitionings are not compatible: $childPartitionings")
+ }
+ }
+ outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach {
+ case (child, requiredDist) =>
+ assert(child.outputPartitioning.satisfies(requiredDist),
+ s"$child output partitioning does not satisfy $requiredDist:\n$outputPlan")
+ }
+ }
+
+ test("EnsureRequirements with incompatible child partitionings which satisfy distribution") {
+ // Consider an operator that requires inputs that are clustered by two expressions (e.g.
+ // sort merge join where there are multiple columns in the equi-join condition)
+ val clusteringA = Literal(1) :: Nil
+ val clusteringB = Literal(2) :: Nil
+ val distribution = ClusteredDistribution(clusteringA ++ clusteringB)
+ // Say that the left and right inputs are each partitioned by _one_ of the two join columns:
+ val leftPartitioning = HashPartitioning(clusteringA, 1)
+ val rightPartitioning = HashPartitioning(clusteringB, 1)
+ // Individually, each input's partitioning satisfies the clustering distribution:
+ assert(leftPartitioning.satisfies(distribution))
+ assert(rightPartitioning.satisfies(distribution))
+ // However, these partitionings are not compatible with each other, so we still need to
+ // repartition both inputs prior to performing the join:
+ assert(!leftPartitioning.compatibleWith(rightPartitioning))
+ assert(!rightPartitioning.compatibleWith(leftPartitioning))
+ val inputPlan = DummySparkPlan(
+ children = Seq(
+ DummySparkPlan(outputPartitioning = leftPartitioning),
+ DummySparkPlan(outputPartitioning = rightPartitioning)
+ ),
+ requiredChildDistribution = Seq(distribution, distribution),
+ requiredChildOrdering = Seq(Seq.empty, Seq.empty)
+ )
+ val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ assertDistributionRequirementsAreSatisfied(outputPlan)
+ if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) {
+ fail(s"Exchange should have been added:\n$outputPlan")
+ }
+ }
+
+ test("EnsureRequirements with child partitionings with different numbers of output partitions") {
+ // This is similar to the previous test, except it checks that partitionings are not compatible
+ // unless they produce the same number of partitions.
+ val clustering = Literal(1) :: Nil
+ val distribution = ClusteredDistribution(clustering)
+ val inputPlan = DummySparkPlan(
+ children = Seq(
+ DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 1)),
+ DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 2))
+ ),
+ requiredChildDistribution = Seq(distribution, distribution),
+ requiredChildOrdering = Seq(Seq.empty, Seq.empty)
+ )
+ val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ assertDistributionRequirementsAreSatisfied(outputPlan)
+ }
+
+ test("EnsureRequirements with compatible child partitionings that do not satisfy distribution") {
+ val distribution = ClusteredDistribution(Literal(1) :: Nil)
+ // The left and right inputs have compatible partitionings but they do not satisfy the
+ // distribution because they are clustered on different columns. Thus, we need to shuffle.
+ val childPartitioning = HashPartitioning(Literal(2) :: Nil, 1)
+ assert(!childPartitioning.satisfies(distribution))
+ val inputPlan = DummySparkPlan(
+ children = Seq(
+ DummySparkPlan(outputPartitioning = childPartitioning),
+ DummySparkPlan(outputPartitioning = childPartitioning)
+ ),
+ requiredChildDistribution = Seq(distribution, distribution),
+ requiredChildOrdering = Seq(Seq.empty, Seq.empty)
+ )
+ val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ assertDistributionRequirementsAreSatisfied(outputPlan)
+ if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) {
+ fail(s"Exchange should have been added:\n$outputPlan")
+ }
+ }
+
+ test("EnsureRequirements with compatible child partitionings that satisfy distribution") {
+ // In this case, all requirements are satisfied and no exchange should be added.
+ val distribution = ClusteredDistribution(Literal(1) :: Nil)
+ val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
+ assert(childPartitioning.satisfies(distribution))
+ val inputPlan = DummySparkPlan(
+ children = Seq(
+ DummySparkPlan(outputPartitioning = childPartitioning),
+ DummySparkPlan(outputPartitioning = childPartitioning)
+ ),
+ requiredChildDistribution = Seq(distribution, distribution),
+ requiredChildOrdering = Seq(Seq.empty, Seq.empty)
+ )
+ val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ assertDistributionRequirementsAreSatisfied(outputPlan)
+ if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) {
+ fail(s"Exchange should not have been added:\n$outputPlan")
+ }
+ }
+
+ // This is a regression test for SPARK-9703
+ test("EnsureRequirements should not repartition if only ordering requirement is unsatisfied") {
+ // Consider an operator that imposes both output distribution and ordering requirements on its
+ // children, such as sort merge join. If the distribution requirements are satisfied but
+ // the output ordering requirements are unsatisfied, then the planner should only add sorts and
+ // should not need to add additional shuffles / exchanges.
+ val outputOrdering = Seq(SortOrder(Literal(1), Ascending))
+ val distribution = ClusteredDistribution(Literal(1) :: Nil)
+ val inputPlan = DummySparkPlan(
+ children = Seq(
+ DummySparkPlan(outputPartitioning = SinglePartition),
+ DummySparkPlan(outputPartitioning = SinglePartition)
+ ),
+ requiredChildDistribution = Seq(distribution, distribution),
+ requiredChildOrdering = Seq(outputOrdering, outputOrdering)
+ )
+ val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ assertDistributionRequirementsAreSatisfied(outputPlan)
+ if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) {
+ fail(s"No Exchanges should have been added:\n$outputPlan")
+ }
+ }
+
+ // ---------------------------------------------------------------------------------------------
+}
+
+// Used for unit-testing EnsureRequirements
+private case class DummySparkPlan(
+ override val children: Seq[SparkPlan] = Nil,
+ override val outputOrdering: Seq[SortOrder] = Nil,
+ override val outputPartitioning: Partitioning = UnknownPartitioning(0),
+ override val requiredChildDistribution: Seq[Distribution] = Nil,
+ override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil
+ ) extends SparkPlan {
+ override protected def doExecute(): RDD[InternalRow] = throw new NotImplementedError
+ override def output: Seq[Attribute] = Seq.empty
}