diff options
author | Josh Rosen <joshrosen@databricks.com> | 2015-10-08 14:53:21 -0700 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2015-10-08 14:56:27 -0700 |
commit | 2816c89b6a304cb0b5214e14ebbc320158e88260 (patch) | |
tree | 41adbf5368a298b0744a33d11588138b98bae5cb | |
parent | 9e66a53c9955285a85c19f55c3ef62db2e1b868a (diff) | |
download | spark-2816c89b6a304cb0b5214e14ebbc320158e88260.tar.gz spark-2816c89b6a304cb0b5214e14ebbc320158e88260.tar.bz2 spark-2816c89b6a304cb0b5214e14ebbc320158e88260.zip |
[SPARK-10988] [SQL] Reduce duplication in Aggregate2's expression rewriting logic
In `aggregate/utils.scala`, there is a substantial amount of duplication in the expression-rewriting logic. As a prerequisite to supporting imperative aggregate functions in `TungstenAggregate`, this patch refactors this file so that the same expression-rewriting logic is used for both `SortBasedAggregate` and `TungstenAggregate`.
In order to allow both operators to use the same rewriting logic, `TungstenAggregationIterator.generateResultProjection()` has been updated so that it first evaluates all declarative aggregate functions' `evaluateExpression`s and writes the results into a temporary buffer, and then uses this temporary buffer and the grouping expressions to evaluate the final `resultExpressions`. This matches the logic in SortAggregateIterator, where this two-pass approach is necessary in order to support imperative aggregates. If this change turns out to cause performance regressions, then we can look into re-implementing the single-pass evaluation in a cleaner way as part of a follow-up patch.
Since the rewriting logic is now shared across both operators, this patch also extracts that logic and places it in `SparkStrategies`. This makes the rewriting logic a bit easier to follow, I think.
Author: Josh Rosen <joshrosen@databricks.com>
Closes #9015 from JoshRosen/SPARK-10988.
5 files changed, 143 insertions, 196 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d1bbf2e20f..79bd1a4180 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -195,19 +195,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { converted match { case None => Nil // Cannot convert to new aggregation code path. case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => - // Extracts all distinct aggregate expressions from the resultExpressions. + // A single aggregate expression might appear multiple times in resultExpressions. + // In order to avoid evaluating an individual aggregate function multiple times, we'll + // build a set of the distinct aggregate expressions and build a function which can + // be used to re-write expressions so that they reference the single copy of the + // aggregate function which actually gets computed. val aggregateExpressions = resultExpressions.flatMap { expr => expr.collect { case agg: AggregateExpression2 => agg } - }.toSet.toSeq + }.distinct // For those distinct aggregate expressions, we create a map from the // aggregate function to the corresponding attribute of the function. 
- val aggregateFunctionMap = aggregateExpressions.map { agg => + val aggregateFunctionToAttribute = aggregateExpressions.map { agg => val aggregateFunction = agg.aggregateFunction - val attribtue = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute - (aggregateFunction, agg.isDistinct) -> - (aggregateFunction -> attribtue) + val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + (aggregateFunction, agg.isDistinct) -> attribute }.toMap val (functionsWithDistinct, functionsWithoutDistinct) = @@ -220,6 +223,40 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "code path.") } + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + + // The original `resultExpressions` are a set of expressions which may reference + // aggregate expressions, grouping column values, and constants. When aggregate operator + // emits output rows, we will use `resultExpressions` to generate an output projection + // which takes the grouping columns and final aggregate result buffer as input. 
+ // Thus, we must re-write the result expressions so that their attributes match up with + // the attributes of the final result projection's input row: + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case AggregateExpression2(aggregateFunction, _, isDistinct) => + // The final aggregation buffer's attributes will be `finalAggregationAttributes`, + // so replace each aggregate expression by its corresponding attribute in the set: + aggregateFunctionToAttribute(aggregateFunction, isDistinct) + case expression => + // Since we're using `namedGroupingAttributes` to extract the grouping key + // columns, we need to replace grouping key expressions with their corresponding + // attributes. We do not rely on the equality check at here since attributes may + // differ cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + val aggregateOperator = if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { if (functionsWithDistinct.nonEmpty) { @@ -227,26 +264,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "aggregate functions which don't support partial aggregation.") } else { aggregate.Utils.planAggregateWithoutPartial( - groupingExpressions, + namedGroupingExpressions.map(_._2), aggregateExpressions, - aggregateFunctionMap, - resultExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, planLater(child)) } } else if (functionsWithDistinct.isEmpty) { aggregate.Utils.planAggregateWithoutDistinct( - groupingExpressions, + namedGroupingExpressions.map(_._2), aggregateExpressions, - aggregateFunctionMap, - resultExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, planLater(child)) } else { aggregate.Utils.planAggregateWithOneDistinct( - groupingExpressions, + 
namedGroupingExpressions.map(_._2), functionsWithDistinct, functionsWithoutDistinct, - aggregateFunctionMap, - resultExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, planLater(child)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 3cd22af305..7b3d072b2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -31,7 +31,9 @@ case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { @@ -77,7 +79,9 @@ case class TungstenAggregate( new TungstenAggregationIterator( groupingExpressions, nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, completeAggregateExpressions, + completeAggregateAttributes, resultExpressions, newMutableProjection, child.output, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index a6f4c1d92f..4bb95c9eb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -60,8 +60,12 @@ import org.apache.spark.sql.types.StructType * @param nonCompleteAggregateExpressions * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode 
[[Partial]], * [[PartialMerge]], or [[Final]]. + * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions' + * outputs when they are stored in the final aggregation buffer. * @param completeAggregateExpressions * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]]. + * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs + * when they are stored in the final aggregation buffer. * @param resultExpressions * expressions for generating output rows. * @param newMutableProjection @@ -72,7 +76,9 @@ import org.apache.spark.sql.types.StructType class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], @@ -280,17 +286,25 @@ class TungstenAggregationIterator( // resultExpressions. 
case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => val joinedRow = new JoinedRow() + val evalExpressions = allAggregateFunctions.map { + case ae: DeclarativeAggregate => ae.evaluateExpression + // case agg: AggregateFunction2 => Literal.create(null, agg.dataType) + } + val expressionAggEvalProjection = UnsafeProjection.create(evalExpressions, bufferAttributes) + // These are the attributes of the row produced by `expressionAggEvalProjection` + val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes val resultProjection = - UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes) + UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema) (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - resultProjection(joinedRow(currentGroupingKey, currentBuffer)) + // Generate results for all expression-based aggregate functions. + val aggregateResult = expressionAggEvalProjection.apply(currentBuffer) + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) } // Grouping-only: a output row is generated from values of grouping expressions. 
case (None, None) => - val resultProjection = - UnsafeProjection.create(resultExpressions, groupingAttributes) + val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { resultProjection(currentGroupingKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index e1c2d9475a..cf6e7ed0d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.aggregate -import scala.collection.mutable - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan} @@ -38,60 +36,35 @@ object Utils { } def planAggregateWithoutPartial( - groupingExpressions: Seq[Expression], + groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], + aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. 
- case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - + val groupingAttributes = groupingExpressions.map(_.toAttribute) val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) - val completeAggregateAttributes = - completeAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 - } - - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] + val completeAggregateAttributes = completeAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } SortBasedAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingExpressions.map(_._2), + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = Nil, nonCompleteAggregateAttributes = Nil, completeAggregateExpressions = completeAggregateExpressions, completeAggregateAttributes = completeAggregateAttributes, initialInputBufferOffset = 0, - resultExpressions = rewrittenResultExpressions, + resultExpressions = resultExpressions, child = child ) :: Nil } def planAggregateWithoutDistinct( - groupingExpressions: Seq[Expression], + groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionMap: 
Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], + aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. @@ -104,36 +77,29 @@ object Utils { // 1. Create an Aggregate Operator for partial aggregations. - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) val partialAggregateAttributes = partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialResultExpressions = - namedGroupingAttributes ++ + groupingAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) val partialAggregate = if (usesTungstenAggregate) { TungstenAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], - groupingExpressions = namedGroupingExpressions.map(_._2), + groupingExpressions = groupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, resultExpressions = partialResultExpressions, child = child) } else { SortBasedAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], - groupingExpressions = 
namedGroupingExpressions.map(_._2), + groupingExpressions = groupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, nonCompleteAggregateAttributes = partialAggregateAttributes, completeAggregateExpressions = Nil, @@ -145,58 +111,32 @@ object Utils { // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) - val finalAggregateAttributes = - finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 - } + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } val finalAggregate = if (usesTungstenAggregate) { - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - // aggregateFunctionMap contains unique aggregate functions. - val aggregateFunction = - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1 - aggregateFunction.asInstanceOf[DeclarativeAggregate].evaluateExpression - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. 
- groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - TungstenAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = Nil, - resultExpressions = rewrittenResultExpressions, + completeAggregateAttributes = Nil, + resultExpressions = resultExpressions, child = partialAggregate) } else { - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. 
- groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - SortBasedAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, - initialInputBufferOffset = namedGroupingAttributes.length, - resultExpressions = rewrittenResultExpressions, + initialInputBufferOffset = groupingExpressions.length, + resultExpressions = resultExpressions, child = partialAggregate) } @@ -204,10 +144,10 @@ object Utils { } def planAggregateWithOneDistinct( - groupingExpressions: Seq[Expression], + groupingExpressions: Seq[NamedExpression], functionsWithDistinct: Seq[AggregateExpression2], functionsWithoutDistinct: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], + aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { @@ -221,20 +161,7 @@ object Utils { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // 1. Create an Aggregate Operator for partial aggregations. - // The grouping expressions are original groupingExpressions and - // distinct columns. For example, for avg(distinct value) ... group by key - // the grouping expressions of this Aggregate Operator will be [key, value]. - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. 
- // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + val groupingAttributes = groupingExpressions.map(_.toAttribute) // It is safe to call head at here since functionsWithDistinct has at least one // AggregateExpression2. @@ -253,22 +180,27 @@ object Utils { val partialAggregateAttributes = partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialAggregateGroupingExpressions = - (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2) + groupingExpressions ++ namedDistinctColumnExpressions.map(_._2) val partialAggregateResult = - namedGroupingAttributes ++ + groupingAttributes ++ distinctColumnAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) val partialAggregate = if (usesTungstenAggregate) { TungstenAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], + requiredChildDistributionExpressions = None, + // The grouping expressions are original groupingExpressions and + // distinct columns. For example, for avg(distinct value) ... group by key + // the grouping expressions of this Aggregate Operator will be [key, value]. 
groupingExpressions = partialAggregateGroupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, resultExpressions = partialAggregateResult, child = child) } else { SortBasedAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], + requiredChildDistributionExpressions = None, groupingExpressions = partialAggregateGroupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, nonCompleteAggregateAttributes = partialAggregateAttributes, @@ -284,41 +216,40 @@ object Utils { val partialMergeAggregateAttributes = partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialMergeAggregateResult = - namedGroupingAttributes ++ + groupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) val partialMergeAggregate = if (usesTungstenAggregate) { TungstenAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, + nonCompleteAggregateAttributes = partialMergeAggregateAttributes, completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, resultExpressions = partialMergeAggregateResult, child = partialAggregate) } else { SortBasedAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes ++ distinctColumnAttributes, 
nonCompleteAggregateExpressions = partialMergeAggregateExpressions, nonCompleteAggregateAttributes = partialMergeAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, - initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) } // 3. Create an Aggregate Operator for partial merge aggregations. val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) - val finalAggregateAttributes = - finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 - } - // Create a map to store those rewritten aggregate functions. We always need to use - // both function and its corresponding isDistinct flag as the key because function itself - // does not knows if it is has distinct keyword or now. - val rewrittenAggregateFunctions = - mutable.Map.empty[(AggregateFunction2, Boolean), AggregateFunction2] + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } + val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children @@ -328,9 +259,6 @@ object Utils { case expr if distinctColumnExpressionMap.contains(expr) => distinctColumnExpressionMap(expr).toAttribute }.asInstanceOf[AggregateFunction2] - // Because we have rewritten the aggregate function, we use rewrittenAggregateFunctions - // to track the old version and the new version of this function. 
- rewrittenAggregateFunctions += (aggregateFunction, true) -> rewrittenAggregateFunction // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, @@ -338,66 +266,30 @@ object Utils { val rewrittenAggregateExpression = AggregateExpression2(rewrittenAggregateFunction, Complete, true) - val aggregateFunctionAttribute = - aggregateFunctionMap(agg.aggregateFunction, true)._2 - (rewrittenAggregateExpression -> aggregateFunctionAttribute) + val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) + (rewrittenAggregateExpression, aggregateFunctionAttribute) }.unzip val finalAndCompleteAggregate = if (usesTungstenAggregate) { - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transform { - case agg: AggregateExpression2 => - val function = agg.aggregateFunction - val isDistinct = agg.isDistinct - val aggregateFunction = - if (rewrittenAggregateFunctions.contains(function, isDistinct)) { - // If this function has been rewritten, we get the rewritten version from - // rewrittenAggregateFunctions. - rewrittenAggregateFunctions(function, isDistinct) - } else { - // Oterwise, we get it from aggregateFunctionMap, which contains unique - // aggregate functions that have not been rewritten. - aggregateFunctionMap(function, isDistinct)._1 - } - aggregateFunction.asInstanceOf[DeclarativeAggregate].evaluateExpression - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. 
- groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - TungstenAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = completeAggregateExpressions, - resultExpressions = rewrittenResultExpressions, + completeAggregateAttributes = completeAggregateAttributes, + resultExpressions = resultExpressions, child = partialMergeAggregate) } else { - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transform { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. 
- groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } SortBasedAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = completeAggregateExpressions, completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = rewrittenResultExpressions, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = resultExpressions, child = partialMergeAggregate) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index 7ca677a6c7..ed974b3a53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -38,7 +38,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte () => new InterpretedMutableProjection(expr, schema) } val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy") - iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, + iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) val numPages = iter.getHashMap.getNumDataPages assert(numPages === 1) |