diff options
author | Takeshi YAMAMURO <linguin.m.s@gmail.com> | 2016-08-30 16:43:47 +0800 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2016-08-30 16:43:47 +0800 |
commit | 94922d79e9f90fac3777db0974ccf7566b8ac3b3 (patch) | |
tree | 858f11e3b3df7644384ee5e5568799693d4e5f4e | |
parent | 8fb445d9bdead6f0ff2bd9879145fe688b3bdc80 (diff) | |
download | spark-94922d79e9f90fac3777db0974ccf7566b8ac3b3.tar.gz spark-94922d79e9f90fac3777db0974ccf7566b8ac3b3.tar.bz2 spark-94922d79e9f90fac3777db0974ccf7566b8ac3b3.zip |
[SPARK-17289][SQL] Fix a bug to satisfy sort requirements in partial aggregations
## What changes were proposed in this pull request?
Partial aggregations are generated in `EnsureRequirements`, but the planner fails to
check if partial aggregation satisfies sort requirements.
For the following query:
```
val df2 = (0 to 1000).map(x => (x % 2, x.toString)).toDF("a", "b").createOrReplaceTempView("t2")
spark.sql("select max(b) from t2 group by a").explain(true)
```
Now, the SortAggregator won't insert Sort operator before partial aggregation, this will break sort-based partial aggregation.
```
== Physical Plan ==
SortAggregate(key=[a#5], functions=[max(b#6)], output=[max(b)#17])
+- *Sort [a#5 ASC], false, 0
+- Exchange hashpartitioning(a#5, 200)
+- SortAggregate(key=[a#5], functions=[partial_max(b#6)], output=[a#5, max#19])
+- LocalTableScan [a#5, b#6]
```
Actually, a correct plan is:
```
== Physical Plan ==
SortAggregate(key=[a#5], functions=[max(b#6)], output=[max(b)#17])
+- *Sort [a#5 ASC], false, 0
+- Exchange hashpartitioning(a#5, 200)
+- SortAggregate(key=[a#5], functions=[partial_max(b#6)], output=[a#5, max#19])
+- *Sort [a#5 ASC], false, 0
+- LocalTableScan [a#5, b#6]
```
## How was this patch tested?
Added tests in `PlannerSuite`.
Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>
Closes #14865 from maropu/SPARK-17289.
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala | 3 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 22 |
2 files changed, 23 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index fee7010e8e..66e99ded24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -164,7 +164,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // If an aggregation needs a shuffle and support partial aggregations, a map-side partial // aggregation and a shuffle are added as children. val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator) - (mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil) + (mergeAgg, createShuffleExchange( + requiredChildDistributions.head, ensureDistributionAndOrdering(mapSideAgg)) :: Nil) case _ => // Ensure that the operator's children satisfy their output distribution requirements: val childrenWithDist = operator.children.zip(requiredChildDistributions) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 07efc72bf6..b0aa3378e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, DataFrame, Row} +import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.aggregate.SortAggregateExec import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} @@ -70,6 +71,25 @@ class PlannerSuite extends SharedSQLContext { s"The plan of query $query does not have partial aggregations.") } + test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") { + withTempView("testSortBasedPartialAggregation") { + val schema = StructType( + StructField(s"key", IntegerType, true) :: StructField(s"value", StringType, true) :: Nil) + val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString))) + spark.createDataFrame(rowRDD, schema) + .createOrReplaceTempView("testSortBasedPartialAggregation") + + // This test assumes a query below uses sort-based aggregations + val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key") + .queryExecution.executedPlan + // This line extracts both SortAggregate and Sort operators + val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n } + val aggOps = extractedOps.collect { case n if n.nodeName contains "SortAggregate" => n } + assert(extractedOps.size == 4 && aggOps.size == 2, + s"The plan $planned does not have correct sort-based partial aggregate pairs.") + } + } + test("non-partial aggregation for aggregates") { withTempView("testNonPartialAggregation") { val schema = StructType(StructField(s"value", IntegerType, true) :: Nil) |