aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/test/scala
diff options
context:
space:
mode:
authorTakeshi YAMAMURO <linguin.m.s@gmail.com>2016-08-30 16:43:47 +0800
committerCheng Lian <lian@databricks.com>2016-08-30 16:43:47 +0800
commit94922d79e9f90fac3777db0974ccf7566b8ac3b3 (patch)
tree858f11e3b3df7644384ee5e5568799693d4e5f4e /sql/core/src/test/scala
parent8fb445d9bdead6f0ff2bd9879145fe688b3bdc80 (diff)
downloadspark-94922d79e9f90fac3777db0974ccf7566b8ac3b3.tar.gz
spark-94922d79e9f90fac3777db0974ccf7566b8ac3b3.tar.bz2
spark-94922d79e9f90fac3777db0974ccf7566b8ac3b3.zip
[SPARK-17289][SQL] Fix a bug to satisfy sort requirements in partial aggregations
## What changes were proposed in this pull request? Partial aggregations are generated in `EnsureRequirements`, but the planner fails to check whether a partial aggregation satisfies its sort requirements. For the following query: ``` val df2 = (0 to 1000).map(x => (x % 2, x.toString)).toDF("a", "b").createOrReplaceTempView("t2") spark.sql("select max(b) from t2 group by a").explain(true) ``` Currently, the SortAggregator does not insert a Sort operator before the partial aggregation, which breaks sort-based partial aggregation: ``` == Physical Plan == SortAggregate(key=[a#5], functions=[max(b#6)], output=[max(b)#17]) +- *Sort [a#5 ASC], false, 0 +- Exchange hashpartitioning(a#5, 200) +- SortAggregate(key=[a#5], functions=[partial_max(b#6)], output=[a#5, max#19]) +- LocalTableScan [a#5, b#6] ``` The correct plan is: ``` == Physical Plan == SortAggregate(key=[a#5], functions=[max(b#6)], output=[max(b)#17]) +- *Sort [a#5 ASC], false, 0 +- Exchange hashpartitioning(a#5, 200) +- SortAggregate(key=[a#5], functions=[partial_max(b#6)], output=[a#5, max#19]) +- *Sort [a#5 ASC], false, 0 +- LocalTableScan [a#5, b#6] ``` ## How was this patch tested? Added tests in `PlannerSuite`. Author: Takeshi YAMAMURO <linguin.m.s@gmail.com> Closes #14865 from maropu/SPARK-17289.
Diffstat (limited to 'sql/core/src/test/scala')
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala22
1 files changed, 21 insertions, 1 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 07efc72bf6..b0aa3378e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,12 +18,13 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, DataFrame, Row}
+import org.apache.spark.sql.{execution, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.aggregate.SortAggregateExec
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -70,6 +71,25 @@ class PlannerSuite extends SharedSQLContext {
s"The plan of query $query does not have partial aggregations.")
}
+ // Regression test for SPARK-17289: when an aggregation is planned as a sort-based
+ // partial/final pair, each SortAggregate operator needs its own Sort operator,
+ // so the executed plan must contain two of each.
+ test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") {
+   withTempView("testSortBasedPartialAggregation") {
+     val schema = StructType(
+       StructField("key", IntegerType, true) :: StructField("value", StringType, true) :: Nil)
+     val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString)))
+     spark.createDataFrame(rowRDD, schema)
+       .createOrReplaceTempView("testSortBasedPartialAggregation")
+
+     // This test assumes a query below uses sort-based aggregations
+     val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key")
+       .queryExecution.executedPlan
+     // Match on operator types rather than on node-name substrings, so that unrelated
+     // operators whose names merely contain "Sort" (e.g. SortMergeJoin) are never counted.
+     val sortOps = planned.collect { case sort: SortExec => sort }
+     val aggOps = planned.collect { case agg: SortAggregateExec => agg }
+     val failureMessage =
+       s"The plan $planned does not have correct sort-based partial aggregate pairs."
+     // One SortAggregate for the partial step and one for the final step ...
+     assert(aggOps.size == 2, failureMessage)
+     // ... and one Sort operator for each of them.
+     assert(sortOps.size == 2, failureMessage)
+   }
+ }
+
test("non-partial aggregation for aggregates") {
withTempView("testNonPartialAggregation") {
val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)