From 2b0cc4e0dfa4ffb9f21ff4a303015bc9c962d42b Mon Sep 17 00:00:00 2001
From: Takeshi YAMAMURO
Date: Thu, 25 Aug 2016 12:39:58 +0200
Subject: [SPARK-12978][SQL] Skip unnecessary final group-by when input data
 already clustered with group-by keys

This ticket adds an optimization that skips the unnecessary final group-by operation shown below when the input data is already clustered by the group-by keys;

Without opt.:
```
== Physical Plan ==
TungstenAggregate(key=[col0#159], functions=[(sum(col1#160),mode=Final,isDistinct=false),(avg(col2#161),mode=Final,isDistinct=false)], output=[col0#159,sum(col1)#177,avg(col2)#178])
+- TungstenAggregate(key=[col0#159], functions=[(sum(col1#160),mode=Partial,isDistinct=false),(avg(col2#161),mode=Partial,isDistinct=false)], output=[col0#159,sum#200,sum#201,count#202L])
+- TungstenExchange hashpartitioning(col0#159,200), None
+- InMemoryColumnarTableScan [col0#159,col1#160,col2#161], InMemoryRelation [col0#159,col1#160,col2#161], true, 10000, StorageLevel(true, true, false, true, 1), ConvertToUnsafe, None
```

With opt.:
```
== Physical Plan ==
TungstenAggregate(key=[col0#159], functions=[(sum(col1#160),mode=Complete,isDistinct=false),(avg(col2#161),mode=Final,isDistinct=false)], output=[col0#159,sum(col1)#177,avg(col2)#178])
+- TungstenExchange hashpartitioning(col0#159,200), None
+- InMemoryColumnarTableScan [col0#159,col1#160,col2#161], InMemoryRelation [col0#159,col1#160,col2#161], true, 10000, StorageLevel(true, true, false, true, 1), ConvertToUnsafe, None
```
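A minimal way to exercise the optimization (a sketch, assuming a local `spark` session; the column names follow the plans above):
```
// Pre-cluster and cache the data by the group-by key; with this patch the
// planner then emits a single aggregate over the clustered input instead of
// a separate partial/final pair.
import spark.implicits._
import org.apache.spark.sql.functions._

val df = spark.range(0, 100000)
  .selectExpr("id % 10 AS col0", "id AS col1", "id * 2 AS col2")
  .repartition($"col0")
  .cache()

df.groupBy("col0").agg(sum("col1"), avg("col2")).explain()
```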
Author: Takeshi YAMAMURO

Closes #10896 from maropu/SkipGroupbySpike.
---
 .../spark/sql/execution/SparkStrategies.scala      |  17 +-
 .../spark/sql/execution/aggregate/AggUtils.scala   | 250 ++++++++++-----------
 .../sql/execution/aggregate/AggregateExec.scala    |  56 +++++
 .../execution/aggregate/HashAggregateExec.scala    |  22 +-
 .../execution/aggregate/SortAggregateExec.scala    |  24 +-
 .../execution/exchange/EnsureRequirements.scala    |  38 +++-
 .../org/apache/spark/sql/DataFrameSuite.scala      |  15 +-
 .../apache/spark/sql/execution/PlannerSuite.scala  |  59 +++--
 8 files changed, 257 insertions(+), 224 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4aaf454285..cda3b2b75e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -259,24 +259,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }

         val aggregateOperator =
-          if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
-            if (functionsWithDistinct.nonEmpty) {
-              sys.error("Distinct columns cannot exist in Aggregate operator containing " +
-                "aggregate functions which don't support partial aggregation.")
-            } else {
-              aggregate.AggUtils.planAggregateWithoutPartial(
-                groupingExpressions,
-                aggregateExpressions,
-                resultExpressions,
-                planLater(child))
-            }
-          } else if (functionsWithDistinct.isEmpty) {
+          if (functionsWithDistinct.isEmpty) {
             aggregate.AggUtils.planAggregateWithoutDistinct(
               groupingExpressions,
               aggregateExpressions,
               resultExpressions,
               planLater(child))
           } else {
+            if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
+              sys.error("Distinct columns cannot exist in Aggregate operator containing " +
+                "aggregate functions which don't support partial aggregation.")
+            }
             aggregate.AggUtils.planAggregateWithOneDistinct(
               groupingExpressions,
               functionsWithDistinct,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 4fbb9d554c..fe75ecea17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -19,34 +19,97 @@ package org.apache.spark.sql.execution.aggregate

 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}

+/**
+ * A pattern that matches aggregate operators that support partial aggregation.
+ */
+object PartialAggregate {
+
+  def unapply(plan: SparkPlan): Option[Distribution] = plan match {
+    case agg: AggregateExec if AggUtils.supportPartialAggregate(agg.aggregateExpressions) =>
+      Some(agg.requiredChildDistribution.head)
+    case _ =>
+      None
+  }
+}
+
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
  */
 object AggUtils {
-  def planAggregateWithoutPartial(
+  def supportPartialAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = {
+    aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial)
+  }
+
+  private def createPartialAggregateExec(
       groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
-      resultExpressions: Seq[NamedExpression],
-      child: SparkPlan): Seq[SparkPlan] = {
+      child: SparkPlan): SparkPlan = {
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
+    val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+    val partialAggregateExpressions = aggregateExpressions.map {
+      case agg @ AggregateExpression(_, _, false, _) if functionsWithDistinct.length > 0 =>
+        agg.copy(mode = PartialMerge)
+      case agg =>
+        agg.copy(mode = Partial)
+    }
+    val partialAggregateAttributes =
+      partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+    val partialResultExpressions =
+      groupingAttributes ++
+        partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)

-    val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
-    val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
-    SortAggregateExec(
-      requiredChildDistributionExpressions = Some(groupingExpressions),
+    createAggregateExec(
+      requiredChildDistributionExpressions = None,
       groupingExpressions = groupingExpressions,
-      aggregateExpressions = completeAggregateExpressions,
-      aggregateAttributes = completeAggregateAttributes,
-      initialInputBufferOffset = 0,
-      resultExpressions = resultExpressions,
-      child = child
-    ) :: Nil
+      aggregateExpressions = partialAggregateExpressions,
+      aggregateAttributes = partialAggregateAttributes,
+      initialInputBufferOffset = if (functionsWithDistinct.length > 0) {
+        groupingExpressions.length + functionsWithDistinct.head.aggregateFunction.children.length
+      } else {
+        0
+      },
+      resultExpressions = partialResultExpressions,
+      child = child)
   }

-  private def createAggregate(
+  private def updateMergeAggregateMode(aggregateExpressions: Seq[AggregateExpression]) = {
+    def updateMode(mode: AggregateMode) = mode match {
+      case Partial => PartialMerge
+      case Complete => Final
+      case mode => mode
+    }
+    aggregateExpressions.map(e => e.copy(mode = updateMode(e.mode)))
+  }
+
+  /**
+   * Builds new merge and map-side [[AggregateExec]]s from an input aggregate operator.
+   * If an aggregation needs a shuffle to satisfy its own distribution and supports partial
+   * aggregation, a map-side aggregation is appended before the shuffle in
+   * [[org.apache.spark.sql.execution.exchange.EnsureRequirements]].
+   */
+  def createMapMergeAggregatePair(operator: SparkPlan): (SparkPlan, SparkPlan) = operator match {
+    case agg: AggregateExec =>
+      val mapSideAgg = createPartialAggregateExec(
+        agg.groupingExpressions, agg.aggregateExpressions, agg.child)
+      val mergeAgg = createAggregateExec(
+        requiredChildDistributionExpressions = agg.requiredChildDistributionExpressions,
+        groupingExpressions = agg.groupingExpressions.map(_.toAttribute),
+        aggregateExpressions = updateMergeAggregateMode(agg.aggregateExpressions),
+        aggregateAttributes = agg.aggregateAttributes,
+        initialInputBufferOffset = agg.groupingExpressions.length,
+        resultExpressions = agg.resultExpressions,
+        child = mapSideAgg
+      )
+
+      (mergeAgg, mapSideAgg)
+  }
+
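For intuition, the mode rewrite in `updateMergeAggregateMode` can be sketched in isolation (the case objects below are stand-ins for Catalyst's `AggregateMode` hierarchy, not the real classes):
```
sealed trait AggregateMode
case object Partial extends AggregateMode      // map side: build aggregation buffers
case object PartialMerge extends AggregateMode // merge side: combine buffers
case object Complete extends AggregateMode     // one stage: raw input to final result
case object Final extends AggregateMode        // merge side: buffers to final result

// Once a map-side partial aggregate is appended below the shuffle, the original
// operator becomes the merge side, so each mode shifts one stage downstream:
def updateMode(mode: AggregateMode): AggregateMode = mode match {
  case Partial  => PartialMerge
  case Complete => Final
  case other    => other
}
```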
+  private def createAggregateExec(
       requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
       groupingExpressions: Seq[NamedExpression] = Nil,
       aggregateExpressions: Seq[AggregateExpression] = Nil,
@@ -55,7 +118,8 @@ object AggUtils {
       resultExpressions: Seq[NamedExpression] = Nil,
       child: SparkPlan): SparkPlan = {
     val useHash = HashAggregateExec.supportsAggregate(
-      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
+      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) &&
+      supportPartialAggregate(aggregateExpressions)
     if (useHash) {
       HashAggregateExec(
         requiredChildDistributionExpressions = requiredChildDistributionExpressions,
@@ -82,43 +146,21 @@
       aggregateExpressions: Seq[AggregateExpression],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
-    // Check if we can use HashAggregate.
-
-    // 1. Create an Aggregate Operator for partial aggregations.
-    val groupingAttributes = groupingExpressions.map(_.toAttribute)
-    val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
-    val partialAggregateAttributes =
-      partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-    val partialResultExpressions =
-      groupingAttributes ++
-        partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-
-    val partialAggregate = createAggregate(
-      requiredChildDistributionExpressions = None,
-      groupingExpressions = groupingExpressions,
-      aggregateExpressions = partialAggregateExpressions,
-      aggregateAttributes = partialAggregateAttributes,
-      initialInputBufferOffset = 0,
-      resultExpressions = partialResultExpressions,
-      child = child)
-
-    // 2. Create an Aggregate Operator for final aggregations.
-    val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
-    // The attributes of the final aggregation buffer, which is presented as input to the result
-    // projection:
-    val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
-
-    val finalAggregate = createAggregate(
-      requiredChildDistributionExpressions = Some(groupingAttributes),
-      groupingExpressions = groupingAttributes,
-      aggregateExpressions = finalAggregateExpressions,
-      aggregateAttributes = finalAggregateAttributes,
-      initialInputBufferOffset = groupingExpressions.length,
-      resultExpressions = resultExpressions,
-      child = partialAggregate)
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
+    val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
+    val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
+    val supportPartial = supportPartialAggregate(aggregateExpressions)

-    finalAggregate :: Nil
+    createAggregateExec(
+      requiredChildDistributionExpressions =
+        Some(if (supportPartial) groupingAttributes else groupingExpressions),
+      groupingExpressions = groupingExpressions,
+      aggregateExpressions = completeAggregateExpressions,
+      aggregateAttributes = completeAggregateAttributes,
+      initialInputBufferOffset = 0,
+      resultExpressions = resultExpressions,
+      child = child
+    ) :: Nil
   }

   def planAggregateWithOneDistinct(
@@ -141,39 +183,23 @@
     val distinctAttributes = namedDistinctExpressions.map(_.toAttribute)
     val groupingAttributes = groupingExpressions.map(_.toAttribute)

-    // 1. Create an Aggregate Operator for partial aggregations.
+    // 1. Create an Aggregate Operator for non-distinct aggregations.
     val partialAggregate: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
       val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
-      // We will group by the original grouping expression, plus an additional expression for the
-      // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
-      // expressions will be [key, value].
-      createAggregate(
-        groupingExpressions = groupingExpressions ++ namedDistinctExpressions,
-        aggregateExpressions = aggregateExpressions,
-        aggregateAttributes = aggregateAttributes,
-        resultExpressions = groupingAttributes ++ distinctAttributes ++
-          aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
-        child = child)
-    }
-
-    // 2. Create an Aggregate Operator for partial merge aggregations.
-    val partialMergeAggregate: SparkPlan = {
-      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
-      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
-      createAggregate(
+      createAggregateExec(
         requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes),
-        groupingExpressions = groupingAttributes ++ distinctAttributes,
+        groupingExpressions = groupingExpressions ++ namedDistinctExpressions,
         aggregateExpressions = aggregateExpressions,
         aggregateAttributes = aggregateAttributes,
         initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
         resultExpressions = groupingAttributes ++ distinctAttributes ++
           aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
-        child = partialAggregate)
+        child = child)
     }

-    // 3. Create an Aggregate operator for partial aggregation (for distinct)
+    // 2. Create an Aggregate Operator for the final aggregation.
     val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap
     val rewrittenDistinctFunctions = functionsWithDistinct.map {
       // Children of an AggregateFunction with DISTINCT keyword has already
@@ -183,38 +209,6 @@ object AggUtils {
         aggregateFunction.transformDown(distinctColumnAttributeLookup)
           .asInstanceOf[AggregateFunction]
     }
-
-    val partialDistinctAggregate: SparkPlan = {
-      val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
-      // The attributes of the final aggregation buffer, which is presented as input to the result
-      // projection:
-      val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute)
-      val (distinctAggregateExpressions, distinctAggregateAttributes) =
-        rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
-          // We rewrite the aggregate function to a non-distinct aggregation because
-          // its input will have distinct arguments.
-          // We just keep the isDistinct setting to true, so when users look at the query plan,
-          // they still can see distinct aggregations.
-          val expr = AggregateExpression(func, Partial, isDistinct = true)
-          // Use original AggregationFunction to lookup attributes, which is used to build
-          // aggregateFunctionToAttribute
-          val attr = functionsWithDistinct(i).resultAttribute
-          (expr, attr)
-        }.unzip
-
-      val partialAggregateResult = groupingAttributes ++
-        mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++
-        distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-      createAggregate(
-        groupingExpressions = groupingAttributes,
-        aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions,
-        aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes,
-        initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
-        resultExpressions = partialAggregateResult,
-        child = partialMergeAggregate)
-    }
-
-    // 4. Create an Aggregate Operator for the final aggregation.
     val finalAndCompleteAggregate: SparkPlan = {
       val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
       // The attributes of the final aggregation buffer, which is presented as input to the result
       // projection:
       val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
       val (distinctAggregateExpressions, distinctAggregateAttributes) =
         rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
           // We rewrite the aggregate function to a non-distinct aggregation because
           // its input will have distinct arguments.
-          // We just keep the isDistinct setting to true, so when users look at the query plan,
-          // they still can see distinct aggregations.
-          val expr = AggregateExpression(func, Final, isDistinct = true)
+          // We keep the isDistinct setting to true because this flag is used to generate partial
+          // aggregations, and it makes aggregation types easy to see in the query plan.
+          val expr = AggregateExpression(func, Complete, isDistinct = true)
           // Use original AggregationFunction to lookup attributes, which is used to build
           // aggregateFunctionToAttribute
           val attr = functionsWithDistinct(i).resultAttribute
           (expr, attr)
-        }.unzip
+        }.unzip

-      createAggregate(
+      createAggregateExec(
         requiredChildDistributionExpressions = Some(groupingAttributes),
         groupingExpressions = groupingAttributes,
         aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions,
         aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes,
         initialInputBufferOffset = groupingAttributes.length,
         resultExpressions = resultExpressions,
-        child = partialDistinctAggregate)
+        child = partialAggregate)
     }

     finalAndCompleteAggregate :: Nil
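Conceptually, `planAggregateWithOneDistinct` now plans two aggregate stages, and `EnsureRequirements` later adds their map-side halves. A sketch of the expected shape (hypothetical data; assumes a `spark` session):
```
// AVG(DISTINCT value) GROUP BY key: stage 1 groups by [key, value] to
// de-duplicate the DISTINCT argument; stage 2 groups by [key] and runs the
// rewritten, now non-distinct function with mode=Complete over those values.
import spark.implicits._

val t = Seq(("a", 1), ("a", 1), ("b", 2)).toDF("key", "value")
t.createOrReplaceTempView("t")
spark.sql("SELECT key, AVG(DISTINCT value) FROM t GROUP BY key").explain()
// Expected: four Aggregate operators in total once the map-side pairs are added.
```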
@@ -249,13 +243,14 @@

   /**
    * Plans a streaming aggregation using the following progression:
-   *  - Partial Aggregation
-   *  - Shuffle
-   *  - Partial Merge (now there is at most 1 tuple per group)
+   *  - Partial Aggregation (now there is at most 1 tuple per group)
    *  - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
    *  - PartialMerge (now there is at most 1 tuple per group)
    *  - StateStoreSave (saves the tuple for the next batch)
    *  - Complete (output the current result of the aggregation)
+   *
+   * If the first aggregation needs a shuffle to satisfy its distribution, a map-side partial
+   * aggregation and a shuffle are added in `EnsureRequirements`.
    */
   def planStreamingAggregation(
       groupingExpressions: Seq[NamedExpression],
@@ -268,39 +263,24 @@
     val partialAggregate: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
       val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
-      // We will group by the original grouping expression, plus an additional expression for the
-      // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
-      // expressions will be [key, value].
- createAggregate( - groupingExpressions = groupingExpressions, - aggregateExpressions = aggregateExpressions, - aggregateAttributes = aggregateAttributes, - resultExpressions = groupingAttributes ++ - aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = child) - } - - val partialMerged1: SparkPlan = { - val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - createAggregate( + createAggregateExec( requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, + groupingExpressions = groupingExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = groupingAttributes.length, resultExpressions = groupingAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = partialAggregate) + child = child) } - val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1) + val restored = StateStoreRestoreExec(groupingAttributes, None, partialAggregate) - val partialMerged2: SparkPlan = { + val partialMerged: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) - createAggregate( + createAggregateExec( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, @@ -314,7 +294,7 @@ object AggUtils { // Note: stateId and returnAllStates are filled in later with preparation rules // in IncrementalExecution. val saved = StateStoreSaveExec( - groupingAttributes, stateId = None, returnAllStates = None, partialMerged2) + groupingAttributes, stateId = None, returnAllStates = None, partialMerged) val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) @@ -322,7 +302,7 @@ object AggUtils { // projection: val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) - createAggregate( + createAggregateExec( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, aggregateExpressions = finalAggregateExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala new file mode 100644 index 0000000000..b88a8aa3da --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.UnaryExecNode
+
+/**
+ * A base class for aggregate implementations.
+ */
+abstract class AggregateExec extends UnaryExecNode {
+
+  def requiredChildDistributionExpressions: Option[Seq[Expression]]
+  def groupingExpressions: Seq[NamedExpression]
+  def aggregateExpressions: Seq[AggregateExpression]
+  def aggregateAttributes: Seq[Attribute]
+  def initialInputBufferOffset: Int
+  def resultExpressions: Seq[NamedExpression]
+
+  protected[this] val aggregateBufferAttributes = {
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+  }
+
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+      AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+      AttributeSet(aggregateBufferAttributes)
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+}
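The `requiredChildDistribution` contract hoisted into this base class can be illustrated with stand-in types (a sketch; Spark's `Distribution` hierarchy is richer, but the mapping is the same):
```
sealed trait Distribution
case object AllTuples extends Distribution // all rows in a single partition
case class ClusteredDistribution(keys: Seq[String]) extends Distribution
case object UnspecifiedDistribution extends Distribution

// Mirrors AggregateExec.requiredChildDistribution above:
def required(exprs: Option[Seq[String]]): Distribution = exprs match {
  case Some(keys) if keys.isEmpty => AllTuples     // global aggregate, no grouping keys
  case Some(keys) => ClusteredDistribution(keys)   // grouped aggregate: cluster by keys
  case None => UnspecifiedDistribution             // map-side partial aggregate: any layout
}
```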
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index bd7efa606e..525c7e301a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
@@ -42,11 +41,7 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
-
-  private[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
+  extends AggregateExec with CodegenSupport {

   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))

@@ -60,21 +55,6 @@
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))

-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-      AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-      AttributeSet(aggregateBufferAttributes)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None =>
-        UnspecifiedDistribution :: Nil
-    }
-  }
-
   // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
   // map and/or the sort-based aggregation once it has processed a given number of input rows.
   private val testFallbackStartsAt: Option[(Int, Int)] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 2a81a823c4..68f86fca80 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.Utils

@@ -38,30 +37,11 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
-
-  private[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
-
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-      AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-      AttributeSet(aggregateBufferAttributes)
+  extends AggregateExec {

   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 446571aa84..951051c4df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.aggregate.AggUtils
+import org.apache.spark.sql.execution.aggregate.PartialAggregate
 import org.apache.spark.sql.internal.SQLConf

 /**
@@ -151,18 +153,30 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
     val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
     val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
-    var children: Seq[SparkPlan] = operator.children
-    assert(requiredChildDistributions.length == children.length)
-    assert(requiredChildOrderings.length == children.length)
+    assert(requiredChildDistributions.length == operator.children.length)
+    assert(requiredChildOrderings.length == operator.children.length)

-    // Ensure that the operator's children satisfy their output distribution requirements:
-    children = children.zip(requiredChildDistributions).map {
-      case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
-        child
-      case (child, BroadcastDistribution(mode)) =>
-        BroadcastExchangeExec(mode, child)
-      case (child, distribution) =>
-        ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
+    def createShuffleExchange(dist: Distribution, child: SparkPlan) =
+      ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child)
+
+    var (parent, children) = operator match {
+      case PartialAggregate(childDist) if !operator.outputPartitioning.satisfies(childDist) =>
+        // If an aggregation needs a shuffle and supports partial aggregation, a map-side partial
+        // aggregation and a shuffle are added as children.
+        val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
+        (mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil)
+      case _ =>
+        // Ensure that the operator's children satisfy their output distribution requirements:
+        val childrenWithDist = operator.children.zip(requiredChildDistributions)
+        val newChildren = childrenWithDist.map {
+          case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
+            child
+          case (child, BroadcastDistribution(mode)) =>
+            BroadcastExchangeExec(mode, child)
+          case (child, distribution) =>
+            createShuffleExchange(distribution, child)
+        }
+        (operator, newChildren)
     }

     // If the operator has multiple children and specifies child output distributions (e.g. join),
@@ -246,7 +260,7 @@
       }
     }

-    operator.withNewChildren(children)
+    parent.withNewChildren(children)
   }

   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 499f318037..cd485770d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1248,17 +1248,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   }

   /**
-   * Verifies that there is no Exchange between the Aggregations for `df`
+   * Verifies that there is a single Aggregation for `df`
    */
-  private def verifyNonExchangingAgg(df: DataFrame) = {
+  private def verifyNonExchangingSingleAgg(df: DataFrame) = {
     var atFirstAgg: Boolean = false
     df.queryExecution.executedPlan.foreach {
       case agg: HashAggregateExec =>
-        atFirstAgg = !atFirstAgg
-      case _ =>
         if (atFirstAgg) {
-          fail("Should not have operators between the two aggregations")
+          fail("Should not have back to back Aggregates")
         }
+        atFirstAgg = true
+      case _ =>
     }
   }

@@ -1292,9 +1292,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     // Group by the column we are distributed by. This should generate a plan with no exchange
     // between the aggregates
     val df3 = testData.repartition($"key").groupBy("key").count()
-    verifyNonExchangingAgg(df3)
-    verifyNonExchangingAgg(testData.repartition($"key", $"value")
+    verifyNonExchangingSingleAgg(df3)
+    verifyNonExchangingSingleAgg(testData.repartition($"key", $"value")
       .groupBy("key", "value").count())
+    verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count())

     // Grouping by just the first distributeBy expr, need to exchange.
     verifyExchangingAgg(testData.repartition($"key", $"value")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 13490c3567..436ff59c4d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution

 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.{execution, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -37,36 +37,65 @@ class PlannerSuite extends SharedSQLContext {

   setupTestData()

-  private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+  private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = {
     val planner = spark.sessionState.planner
     import planner._
-    val plannedOption = Aggregation(query).headOption
-    val planned =
-      plannedOption.getOrElse(
-        fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
-    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-
-    // For the new aggregation code path, there will be four aggregate operator for
-    // distinct aggregations.
-    assert(
-      aggregations.size == 2 || aggregations.size == 4,
-      s"The plan of query $query does not have partial aggregations.")
+    val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
+    val planned = Aggregation(query).headOption.map(ensureRequirements(_))
+      .getOrElse(fail(s"Could not plan aggregation query $query. Is it an aggregation query?"))
+    planned.collect { case n if n.nodeName contains "Aggregate" => n }
   }

   test("count is partially aggregated") {
     val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
-    testPartialAggregationPlan(query)
+    assert(testPartialAggregationPlan(query).size == 2,
+      s"The plan of query $query does not have partial aggregations.")
   }

   test("count distinct is partially aggregated") {
     val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
     testPartialAggregationPlan(query)
+    // For the new aggregation code path, there will be four aggregate operators for distinct
+    // aggregations.
+    assert(testPartialAggregationPlan(query).size == 4,
+      s"The plan of query $query does not have partial aggregations.")
   }

   test("mixed aggregates are partially aggregated") {
     val query = testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    testPartialAggregationPlan(query)
+    // For the new aggregation code path, there will be four aggregate operators for distinct
+    // aggregations.
+    assert(testPartialAggregationPlan(query).size == 4,
+      s"The plan of query $query does not have partial aggregations.")
+  }
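The counts these tests assert can also be checked from user code; a sketch with a hypothetical helper (not part of this patch):
```
import org.apache.spark.sql.DataFrame

// Counts physical aggregate operators in the executed plan, mirroring the
// suite's collect-by-nodeName logic.
def countAggregates(df: DataFrame): Int =
  df.queryExecution.executedPlan.collect {
    case n if n.nodeName contains "Aggregate" => n
  }.size

// Illustrative expectations under this patch (testData has key/value columns):
//   testData.groupBy("key").count()                      -> 2 aggregates
//   testData.groupBy("key").agg(countDistinct("value"))  -> 4 aggregates
//   testData.repartition($"key").groupBy("key").count()  -> 1 aggregate
```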
+
+  test("non-partial aggregation for aggregates") {
+    withTempView("testNonPartialAggregation") {
+      val schema = StructType(StructField("value", IntegerType, true) :: Nil)
+      val row = Row.fromSeq(Seq.fill(1)(null))
+      val rowRDD = sparkContext.parallelize(row :: Nil)
+      spark.createDataFrame(rowRDD, schema).repartition($"value")
+        .createOrReplaceTempView("testNonPartialAggregation")
+
+      val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
+        .queryExecution.executedPlan
+
+      // If the input data is already partitioned and the same columns are used as grouping keys
+      // and aggregation values, no partial aggregation exists in the query plan.
+      val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
+      assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")
+
+      val planned2 = sql(
+        """
+          |SELECT t.value, SUM(DISTINCT t.value)
+          |FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
+          |GROUP BY t.value
+        """.stripMargin).queryExecution.executedPlan
+
+      val aggOps2 = planned2.collect { case n if n.nodeName contains "Aggregate" => n }
+      assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.")
+    }
   }

   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
--
cgit v1.2.3