about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala | 3
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala | 11
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 40
3 files changed, 52 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index de779ed370..f498f35792 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -61,6 +61,9 @@ case class SortOrder(child: Expression, direction: SortDirection)
override def sql: String = child.sql + " " + direction.sql
def isAscending: Boolean = direction == Ascending
+
+ /**
+  * Returns true when `other` has the same sort direction as this ordering and this
+  * ordering's child reports `other`'s child as semantically equal.
+  *
+  * @param other the sort order to compare against
+  */
+ def semanticEquals(other: SortOrder): Boolean = {
+   // Compare directions first so the child comparison is skipped when they already differ.
+   val sameDirection = direction == other.direction
+   sameDirection && child.semanticEquals(other.child)
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 951051c4df..fee7010e8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -250,7 +250,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
if (requiredOrdering.nonEmpty) {
// If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
- if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) {
+ val orderingMatched = if (requiredOrdering.length > child.outputOrdering.length) {
+ false
+ } else {
+ requiredOrdering.zip(child.outputOrdering).forall {
+ case (requiredOrder, childOutputOrder) =>
+ requiredOrder.semanticEquals(childOutputOrder)
+ }
+ }
+
+ if (!orderingMatched) {
SortExec(requiredOrdering, global = false, child = child)
} else {
child
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 436ff59c4d..07efc72bf6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
@@ -444,6 +444,44 @@ class PlannerSuite extends SharedSQLContext {
}
}
+ test("EnsureRequirements skips sort when required ordering is semantically equal to " +
+   "existing ordering") {
+   // Both attributes are built with the same name, type, nullability and exprId; only the
+   // first one is given an explicit qualifier.
+   val sharedExprId = NamedExpression.newExprId
+   val qualifiedAttr = AttributeReference(
+     name = "col1",
+     dataType = LongType,
+     nullable = false
+   )(exprId = sharedExprId, qualifier = Some("col1_qualifier"))
+   val unqualifiedAttr = AttributeReference(
+     name = "col1",
+     dataType = LongType,
+     nullable = false
+   )(exprId = sharedExprId)
+
+   val orderingA1 = SortOrder(qualifiedAttr, Ascending)
+   val orderingA2 = SortOrder(unqualifiedAttr, Ascending)
+
+   // Preconditions for the scenario: the two orderings must differ under plain equality
+   // yet still be reported as semantically equal.
+   assert(orderingA1 != orderingA2, s"$orderingA1 should NOT equal to $orderingA2")
+   assert(orderingA1.semanticEquals(orderingA2),
+     s"$orderingA1 should be semantically equal to $orderingA2")
+
+   // The child already outputs orderingA1 while the parent requires orderingA2.
+   val sortedChild = DummySparkPlan(outputOrdering = Seq(orderingA1))
+   val inputPlan = DummySparkPlan(
+     children = sortedChild :: Nil,
+     requiredChildOrdering = Seq(Seq(orderingA2)),
+     requiredChildDistribution = Seq(UnspecifiedDistribution))
+   val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
+   val outputPlan = ensureRequirements.apply(inputPlan)
+   assertDistributionRequirementsAreSatisfied(outputPlan)
+
+   // Any SortExec in the resulting plan means a redundant sort was inserted.
+   val addedSorts = outputPlan.collect { case sort: SortExec => sort }
+   if (addedSorts.nonEmpty) {
+     fail(s"No sorts should have been added:\n$outputPlan")
+   }
+ }
+
// This is a regression test for SPARK-11135
test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") {
val orderingA = SortOrder(Literal(1), Ascending)