about | summary | refs | log | tree | commit | diff
path: root/sql/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core/src')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala | 3
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 22
2 files changed, 23 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index fee7010e8e..66e99ded24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -164,7 +164,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
// If an aggregation needs a shuffle and support partial aggregations, a map-side partial
// aggregation and a shuffle are added as children.
val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
- (mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil)
+ (mergeAgg, createShuffleExchange(
+ requiredChildDistributions.head, ensureDistributionAndOrdering(mapSideAgg)) :: Nil)
case _ =>
// Ensure that the operator's children satisfy their output distribution requirements:
val childrenWithDist = operator.children.zip(requiredChildDistributions)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 07efc72bf6..b0aa3378e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,12 +18,13 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, DataFrame, Row}
+import org.apache.spark.sql.{execution, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.aggregate.SortAggregateExec
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -70,6 +71,25 @@ class PlannerSuite extends SharedSQLContext {
s"The plan of query $query does not have partial aggregations.")
}
+ test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") {
+ withTempView("testSortBasedPartialAggregation") {
+ val schema = StructType(
+ StructField(s"key", IntegerType, true) :: StructField(s"value", StringType, true) :: Nil)
+ val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString)))
+ spark.createDataFrame(rowRDD, schema)
+ .createOrReplaceTempView("testSortBasedPartialAggregation")
+
+ // Assumes the query below (MAX over the string column) is planned with sort-based aggregation.
+ val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key")
+ .queryExecution.executedPlan
+ // Collect every node whose name contains "Sort"; this matches SortAggregate as well as Sort.
+ val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n }
+ val aggOps = extractedOps.collect { case n if n.nodeName contains "SortAggregate" => n }
+ assert(extractedOps.size == 4 && aggOps.size == 2,
+ s"The plan $planned does not have correct sort-based partial aggregate pairs.")
+ }
+ }
+
test("non-partial aggregation for aggregates") {
withTempView("testNonPartialAggregation") {
val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)