aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/test/scala
diff options
context:
space:
mode:
authorTakeshi YAMAMURO <linguin.m.s@gmail.com>2016-08-25 12:39:58 +0200
committerHerman van Hovell <hvanhovell@databricks.com>2016-08-25 12:39:58 +0200
commit2b0cc4e0dfa4ffb9f21ff4a303015bc9c962d42b (patch)
treefcf6b3fea6c9604bee57830c007f9d82ea6937c3 /sql/core/src/test/scala
parent6b8cb1fe52e2c8b4b87b0c7d820f3a1824287328 (diff)
downloadspark-2b0cc4e0dfa4ffb9f21ff4a303015bc9c962d42b.tar.gz
spark-2b0cc4e0dfa4ffb9f21ff4a303015bc9c962d42b.tar.bz2
spark-2b0cc4e0dfa4ffb9f21ff4a303015bc9c962d42b.zip
[SPARK-12978][SQL] Skip unnecessary final group-by when input data already clustered with group-by keys
This ticket targets the optimization to skip an unnecessary group-by operation below; Without opt.: ``` == Physical Plan == TungstenAggregate(key=[col0#159], functions=[(sum(col1#160),mode=Final,isDistinct=false),(avg(col2#161),mode=Final,isDistinct=false)], output=[col0#159,sum(col1)#177,avg(col2)#178]) +- TungstenAggregate(key=[col0#159], functions=[(sum(col1#160),mode=Partial,isDistinct=false),(avg(col2#161),mode=Partial,isDistinct=false)], output=[col0#159,sum#200,sum#201,count#202L]) +- TungstenExchange hashpartitioning(col0#159,200), None +- InMemoryColumnarTableScan [col0#159,col1#160,col2#161], InMemoryRelation [col0#159,col1#160,col2#161], true, 10000, StorageLevel(true, true, false, true, 1), ConvertToUnsafe, None ``` With opt.: ``` == Physical Plan == TungstenAggregate(key=[col0#159], functions=[(sum(col1#160),mode=Complete,isDistinct=false),(avg(col2#161),mode=Final,isDistinct=false)], output=[col0#159,sum(col1)#177,avg(col2)#178]) +- TungstenExchange hashpartitioning(col0#159,200), None +- InMemoryColumnarTableScan [col0#159,col1#160,col2#161], InMemoryRelation [col0#159,col1#160,col2#161], true, 10000, StorageLevel(true, true, false, true, 1), ConvertToUnsafe, None ``` Author: Takeshi YAMAMURO <linguin.m.s@gmail.com> Closes #10896 from maropu/SkipGroupbySpike.
Diffstat (limited to 'sql/core/src/test/scala')
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala15
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala59
2 files changed, 52 insertions, 22 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 499f318037..cd485770d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1248,17 +1248,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
/**
- * Verifies that there is no Exchange between the Aggregations for `df`
+ * Verifies that there is a single Aggregation for `df`
*/
- private def verifyNonExchangingAgg(df: DataFrame) = {
+ private def verifyNonExchangingSingleAgg(df: DataFrame) = {
var atFirstAgg: Boolean = false
df.queryExecution.executedPlan.foreach {
case agg: HashAggregateExec =>
- atFirstAgg = !atFirstAgg
- case _ =>
if (atFirstAgg) {
- fail("Should not have operators between the two aggregations")
+ fail("Should not have back to back Aggregates")
}
+ atFirstAgg = true
+ case _ =>
}
}
@@ -1292,9 +1292,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
// Group by the column we are distributed by. This should generate a plan with no exchange
// between the aggregates
val df3 = testData.repartition($"key").groupBy("key").count()
- verifyNonExchangingAgg(df3)
- verifyNonExchangingAgg(testData.repartition($"key", $"value")
+ verifyNonExchangingSingleAgg(df3)
+ verifyNonExchangingSingleAgg(testData.repartition($"key", $"value")
.groupBy("key", "value").count())
+ verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count())
// Grouping by just the first distributeBy expr, need to exchange.
verifyExchangingAgg(testData.repartition($"key", $"value")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 13490c3567..436ff59c4d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.{execution, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.Inner
@@ -37,36 +37,65 @@ class PlannerSuite extends SharedSQLContext {
setupTestData()
- private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+ private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = {
val planner = spark.sessionState.planner
import planner._
- val plannedOption = Aggregation(query).headOption
- val planned =
- plannedOption.getOrElse(
- fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
- val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-
- // For the new aggregation code path, there will be four aggregate operator for
- // distinct aggregations.
- assert(
- aggregations.size == 2 || aggregations.size == 4,
- s"The plan of query $query does not have partial aggregations.")
+ val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
+ val planned = Aggregation(query).headOption.map(ensureRequirements(_))
+ .getOrElse(fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
+ planned.collect { case n if n.nodeName contains "Aggregate" => n }
}
test("count is partially aggregated") {
val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
- testPartialAggregationPlan(query)
+ assert(testPartialAggregationPlan(query).size == 2,
+ s"The plan of query $query does not have partial aggregations.")
}
test("count distinct is partially aggregated") {
val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
+ // For the new aggregation code path, there will be four aggregate operator for distinct
+ // aggregations.
+ assert(testPartialAggregationPlan(query).size == 4,
+ s"The plan of query $query does not have partial aggregations.")
}
test("mixed aggregates are partially aggregated") {
val query =
testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
- testPartialAggregationPlan(query)
+ // For the new aggregation code path, there will be four aggregate operator for distinct
+ // aggregations.
+ assert(testPartialAggregationPlan(query).size == 4,
+ s"The plan of query $query does not have partial aggregations.")
+ }
+
+ test("non-partial aggregation for aggregates") {
+ withTempView("testNonPartialAggregation") {
+ val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)
+ val row = Row.fromSeq(Seq.fill(1)(null))
+ val rowRDD = sparkContext.parallelize(row :: Nil)
+ spark.createDataFrame(rowRDD, schema).repartition($"value")
+ .createOrReplaceTempView("testNonPartialAggregation")
+
+ val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
+ .queryExecution.executedPlan
+
+ // If input data are already partitioned and the same columns are used in grouping keys and
+ // aggregation values, no partial aggregation exist in query plans.
+ val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
+ assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")
+
+ val planned2 = sql(
+ """
+ |SELECT t.value, SUM(DISTINCT t.value)
+ |FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
+ |GROUP BY t.value
+ """.stripMargin).queryExecution.executedPlan
+
+ val aggOps2 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
+ assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.")
+ }
}
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {