diff options
Diffstat (limited to 'sql')
3 files changed, 52 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index de779ed370..f498f35792 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -61,6 +61,9 @@ case class SortOrder(child: Expression, direction: SortDirection) override def sql: String = child.sql + " " + direction.sql def isAscending: Boolean = direction == Ascending + + def semanticEquals(other: SortOrder): Boolean = + (direction == other.direction) && child.semanticEquals(other.child) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 951051c4df..fee7010e8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -250,7 +250,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => if (requiredOrdering.nonEmpty) { // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. - if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { + val orderingMatched = if (requiredOrdering.length > child.outputOrdering.length) { + false + } else { + requiredOrdering.zip(child.outputOrdering).forall { + case (requiredOrder, childOutputOrder) => + requiredOrder.semanticEquals(childOutputOrder) + } + } + + if (!orderingMatched) { SortExec(requiredOrdering, global = false, child = child) } else { child diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 436ff59c4d..07efc72bf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ @@ -444,6 +444,44 @@ class PlannerSuite extends SharedSQLContext { } } + test("EnsureRequirements skips sort when required ordering is semantically equal to " + + "existing ordering") { + val exprId: ExprId = NamedExpression.newExprId + val attribute1 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId, + qualifier = Some("col1_qualifier") + ) + + val attribute2 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId) + + val orderingA1 = SortOrder(attribute1, Ascending) + val orderingA2 = SortOrder(attribute2, Ascending) + + assert(orderingA1 != orderingA2, s"$orderingA1 should NOT equal to $orderingA2") + assert(orderingA1.semanticEquals(orderingA2), + s"$orderingA1 should be semantically equal to $orderingA2") + + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputOrdering = Seq(orderingA1)) :: Nil, + requiredChildOrdering = Seq(Seq(orderingA2)), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { + fail(s"No sorts should have been added:\n$outputPlan") + } + } + // This is a regression test for SPARK-11135 test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") { val orderingA = SortOrder(Literal(1), Ascending) |