aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala39
2 files changed, 45 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 7b4161930b..6b10057707 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -467,6 +467,12 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
}
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+ case operator @ Exchange(partitioning, child, _) =>
+ child.children match {
+ case Exchange(childPartitioning, baseChild, _)::Nil =>
+ if (childPartitioning.guarantees(partitioning)) child else operator
+ case _ => operator
+ }
case operator: SparkPlan => ensureDistributionAndOrdering(operator)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 858e289c27..03a1b8e11d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -417,6 +417,45 @@ class PlannerSuite extends SharedSQLContext {
}
}
+ test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") {
+ val distribution = ClusteredDistribution(Literal(1) :: Nil)
+ val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
+ val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
+ assert(!childPartitioning.satisfies(distribution))
+ val inputPlan = Exchange(finalPartitioning,
+ DummySparkPlan(
+ children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
+ requiredChildDistribution = Seq(distribution),
+ requiredChildOrdering = Seq(Seq.empty)),
+ None)
+
+ val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ assertDistributionRequirementsAreSatisfied(outputPlan)
+ if (outputPlan.collect { case e: Exchange => true }.size == 2) {
+ fail(s"Topmost Exchange should have been eliminated:\n$outputPlan")
+ }
+ }
+
+ test("EnsureRequirements does not eliminate Exchange with different partitioning") {
+ val distribution = ClusteredDistribution(Literal(1) :: Nil)
+ // Number of partitions differ
+ val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8)
+ val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
+ assert(!childPartitioning.satisfies(distribution))
+ val inputPlan = Exchange(finalPartitioning,
+ DummySparkPlan(
+ children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
+ requiredChildDistribution = Seq(distribution),
+ requiredChildOrdering = Seq(Seq.empty)),
+ None)
+
+ val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ assertDistributionRequirementsAreSatisfied(outputPlan)
+ if (outputPlan.collect { case e: Exchange => true }.size == 1) {
+ fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan")
+ }
+ }
+
// ---------------------------------------------------------------------------------------------
}