aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2015-10-08 14:53:21 -0700
committerYin Huai <yhuai@databricks.com>2015-10-08 14:56:27 -0700
commit2816c89b6a304cb0b5214e14ebbc320158e88260 (patch)
tree41adbf5368a298b0744a33d11588138b98bae5cb /sql
parent9e66a53c9955285a85c19f55c3ef62db2e1b868a (diff)
downloadspark-2816c89b6a304cb0b5214e14ebbc320158e88260.tar.gz
spark-2816c89b6a304cb0b5214e14ebbc320158e88260.tar.bz2
spark-2816c89b6a304cb0b5214e14ebbc320158e88260.zip
[SPARK-10988] [SQL] Reduce duplication in Aggregate2's expression rewriting logic
In `aggregate/utils.scala`, there is a substantial amount of duplication in the expression-rewriting logic. As a prerequisite to supporting imperative aggregate functions in `TungstenAggregate`, this patch refactors this file so that the same expression-rewriting logic is used for both `SortAggregate` and `TungstenAggregate`. In order to allow both operators to use the same rewriting logic, `TungstenAggregationIterator. generateResultProjection()` has been updated so that it first evaluates all declarative aggregate functions' `evaluateExpression`s and writes the results into a temporary buffer, and then uses this temporary buffer and the grouping expressions to evaluate the final resultExpressions. This matches the logic in SortAggregateIterator, where this two-pass approach is necessary in order to support imperative aggregates. If this change turns out to cause performance regressions, then we can look into re-implementing the single-pass evaluation in a cleaner way as part of a followup patch. Since the rewriting logic is now shared across both operators, this patch also extracts that logic and places it in `SparkStrategies`. This makes the rewriting logic a bit easier to follow, I think. Author: Josh Rosen <joshrosen@databricks.com> Closes #9015 from JoshRosen/SPARK-10988.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala67
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala22
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala244
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala2
5 files changed, 143 insertions, 196 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index d1bbf2e20f..79bd1a4180 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -195,19 +195,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
converted match {
case None => Nil // Cannot convert to new aggregation code path.
case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
- // Extracts all distinct aggregate expressions from the resultExpressions.
+ // A single aggregate expression might appear multiple times in resultExpressions.
+ // In order to avoid evaluating an individual aggregate function multiple times, we'll
+ // build a set of the distinct aggregate expressions and build a function which can
+ // be used to re-write expressions so that they reference the single copy of the
+ // aggregate function which actually gets computed.
val aggregateExpressions = resultExpressions.flatMap { expr =>
expr.collect {
case agg: AggregateExpression2 => agg
}
- }.toSet.toSeq
+ }.distinct
// For those distinct aggregate expressions, we create a map from the
// aggregate function to the corresponding attribute of the function.
- val aggregateFunctionMap = aggregateExpressions.map { agg =>
+ val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
val aggregateFunction = agg.aggregateFunction
- val attribtue = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
- (aggregateFunction, agg.isDistinct) ->
- (aggregateFunction -> attribtue)
+ val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+ (aggregateFunction, agg.isDistinct) -> attribute
}.toMap
val (functionsWithDistinct, functionsWithoutDistinct) =
@@ -220,6 +223,40 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
"code path.")
}
+ val namedGroupingExpressions = groupingExpressions.map {
+ case ne: NamedExpression => ne -> ne
+ // If the expression is not a NamedExpressions, we add an alias.
+ // So, when we generate the result of the operator, the Aggregate Operator
+ // can directly get the Seq of attributes representing the grouping expressions.
+ case other =>
+ val withAlias = Alias(other, other.toString)()
+ other -> withAlias
+ }
+ val groupExpressionMap = namedGroupingExpressions.toMap
+
+ // The original `resultExpressions` are a set of expressions which may reference
+ // aggregate expressions, grouping column values, and constants. When aggregate operator
+ // emits output rows, we will use `resultExpressions` to generate an output projection
+ // which takes the grouping columns and final aggregate result buffer as input.
+ // Thus, we must re-write the result expressions so that their attributes match up with
+ // the attributes of the final result projection's input row:
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transformDown {
+ case AggregateExpression2(aggregateFunction, _, isDistinct) =>
+ // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
+ // so replace each aggregate expression by its corresponding attribute in the set:
+ aggregateFunctionToAttribute(aggregateFunction, isDistinct)
+ case expression =>
+ // Since we're using `namedGroupingAttributes` to extract the grouping key
+ // columns, we need to replace grouping key expressions with their corresponding
+ // attributes. We do not rely on the equality check at here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+
val aggregateOperator =
if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
if (functionsWithDistinct.nonEmpty) {
@@ -227,26 +264,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
"aggregate functions which don't support partial aggregation.")
} else {
aggregate.Utils.planAggregateWithoutPartial(
- groupingExpressions,
+ namedGroupingExpressions.map(_._2),
aggregateExpressions,
- aggregateFunctionMap,
- resultExpressions,
+ aggregateFunctionToAttribute,
+ rewrittenResultExpressions,
planLater(child))
}
} else if (functionsWithDistinct.isEmpty) {
aggregate.Utils.planAggregateWithoutDistinct(
- groupingExpressions,
+ namedGroupingExpressions.map(_._2),
aggregateExpressions,
- aggregateFunctionMap,
- resultExpressions,
+ aggregateFunctionToAttribute,
+ rewrittenResultExpressions,
planLater(child))
} else {
aggregate.Utils.planAggregateWithOneDistinct(
- groupingExpressions,
+ namedGroupingExpressions.map(_._2),
functionsWithDistinct,
functionsWithoutDistinct,
- aggregateFunctionMap,
- resultExpressions,
+ aggregateFunctionToAttribute,
+ rewrittenResultExpressions,
planLater(child))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 3cd22af305..7b3d072b2e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -31,7 +31,9 @@ case class TungstenAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]],
groupingExpressions: Seq[NamedExpression],
nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryNode {
@@ -77,7 +79,9 @@ case class TungstenAggregate(
new TungstenAggregationIterator(
groupingExpressions,
nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
completeAggregateExpressions,
+ completeAggregateAttributes,
resultExpressions,
newMutableProjection,
child.output,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index a6f4c1d92f..4bb95c9eb7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -60,8 +60,12 @@ import org.apache.spark.sql.types.StructType
* @param nonCompleteAggregateExpressions
* [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]],
* [[PartialMerge]], or [[Final]].
+ * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions'
+ * outputs when they are stored in the final aggregation buffer.
* @param completeAggregateExpressions
* [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]].
+ * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs
+ * when they are stored in the final aggregation buffer.
* @param resultExpressions
* expressions for generating output rows.
* @param newMutableProjection
@@ -72,7 +76,9 @@ import org.apache.spark.sql.types.StructType
class TungstenAggregationIterator(
groupingExpressions: Seq[NamedExpression],
nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
@@ -280,17 +286,25 @@ class TungstenAggregationIterator(
// resultExpressions.
case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
val joinedRow = new JoinedRow()
+ val evalExpressions = allAggregateFunctions.map {
+ case ae: DeclarativeAggregate => ae.evaluateExpression
+ // case agg: AggregateFunction2 => Literal.create(null, agg.dataType)
+ }
+ val expressionAggEvalProjection = UnsafeProjection.create(evalExpressions, bufferAttributes)
+ // These are the attributes of the row produced by `expressionAggEvalProjection`
+ val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
val resultProjection =
- UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes)
+ UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema)
(currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
- resultProjection(joinedRow(currentGroupingKey, currentBuffer))
+ // Generate results for all expression-based aggregate functions.
+ val aggregateResult = expressionAggEvalProjection.apply(currentBuffer)
+ resultProjection(joinedRow(currentGroupingKey, aggregateResult))
}
// Grouping-only: a output row is generated from values of grouping expressions.
case (None, None) =>
- val resultProjection =
- UnsafeProjection.create(resultExpressions, groupingAttributes)
+ val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
(currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
resultProjection(currentGroupingKey)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index e1c2d9475a..cf6e7ed0d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.execution.aggregate
-import scala.collection.mutable
-
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan}
@@ -38,60 +36,35 @@ object Utils {
}
def planAggregateWithoutPartial(
- groupingExpressions: Seq[Expression],
+ groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression2],
- aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
+ aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
- val namedGroupingExpressions = groupingExpressions.map {
- case ne: NamedExpression => ne -> ne
- // If the expression is not a NamedExpressions, we add an alias.
- // So, when we generate the result of the operator, the Aggregate Operator
- // can directly get the Seq of attributes representing the grouping expressions.
- case other =>
- val withAlias = Alias(other, other.toString)()
- other -> withAlias
- }
- val groupExpressionMap = namedGroupingExpressions.toMap
- val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
-
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
- val completeAggregateAttributes =
- completeAggregateExpressions.map {
- expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
- }
-
- val rewrittenResultExpressions = resultExpressions.map { expr =>
- expr.transformDown {
- case agg: AggregateExpression2 =>
- aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
- case expression =>
- // We do not rely on the equality check at here since attributes may
- // different cosmetically. Instead, we use semanticEquals.
- groupExpressionMap.collectFirst {
- case (expr, ne) if expr semanticEquals expression => ne.toAttribute
- }.getOrElse(expression)
- }.asInstanceOf[NamedExpression]
+ val completeAggregateAttributes = completeAggregateExpressions.map {
+ expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
}
SortBasedAggregate(
- requiredChildDistributionExpressions = Some(namedGroupingAttributes),
- groupingExpressions = namedGroupingExpressions.map(_._2),
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
nonCompleteAggregateExpressions = Nil,
nonCompleteAggregateAttributes = Nil,
completeAggregateExpressions = completeAggregateExpressions,
completeAggregateAttributes = completeAggregateAttributes,
initialInputBufferOffset = 0,
- resultExpressions = rewrittenResultExpressions,
+ resultExpressions = resultExpressions,
child = child
) :: Nil
}
def planAggregateWithoutDistinct(
- groupingExpressions: Seq[Expression],
+ groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression2],
- aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
+ aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
// Check if we can use TungstenAggregate.
@@ -104,36 +77,29 @@ object Utils {
// 1. Create an Aggregate Operator for partial aggregations.
- val namedGroupingExpressions = groupingExpressions.map {
- case ne: NamedExpression => ne -> ne
- // If the expression is not a NamedExpressions, we add an alias.
- // So, when we generate the result of the operator, the Aggregate Operator
- // can directly get the Seq of attributes representing the grouping expressions.
- case other =>
- val withAlias = Alias(other, other.toString)()
- other -> withAlias
- }
- val groupExpressionMap = namedGroupingExpressions.toMap
- val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
val partialAggregateAttributes =
partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
val partialResultExpressions =
- namedGroupingAttributes ++
+ groupingAttributes ++
partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
val partialAggregate = if (usesTungstenAggregate) {
TungstenAggregate(
requiredChildDistributionExpressions = None: Option[Seq[Expression]],
- groupingExpressions = namedGroupingExpressions.map(_._2),
+ groupingExpressions = groupingExpressions,
nonCompleteAggregateExpressions = partialAggregateExpressions,
+ nonCompleteAggregateAttributes = partialAggregateAttributes,
completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
resultExpressions = partialResultExpressions,
child = child)
} else {
SortBasedAggregate(
requiredChildDistributionExpressions = None: Option[Seq[Expression]],
- groupingExpressions = namedGroupingExpressions.map(_._2),
+ groupingExpressions = groupingExpressions,
nonCompleteAggregateExpressions = partialAggregateExpressions,
nonCompleteAggregateAttributes = partialAggregateAttributes,
completeAggregateExpressions = Nil,
@@ -145,58 +111,32 @@ object Utils {
// 2. Create an Aggregate Operator for final aggregations.
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
- val finalAggregateAttributes =
- finalAggregateExpressions.map {
- expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
- }
+ // The attributes of the final aggregation buffer, which is presented as input to the result
+ // projection:
+ val finalAggregateAttributes = finalAggregateExpressions.map {
+ expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
+ }
val finalAggregate = if (usesTungstenAggregate) {
- val rewrittenResultExpressions = resultExpressions.map { expr =>
- expr.transformDown {
- case agg: AggregateExpression2 =>
- // aggregateFunctionMap contains unique aggregate functions.
- val aggregateFunction =
- aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1
- aggregateFunction.asInstanceOf[DeclarativeAggregate].evaluateExpression
- case expression =>
- // We do not rely on the equality check at here since attributes may
- // different cosmetically. Instead, we use semanticEquals.
- groupExpressionMap.collectFirst {
- case (expr, ne) if expr semanticEquals expression => ne.toAttribute
- }.getOrElse(expression)
- }.asInstanceOf[NamedExpression]
- }
-
TungstenAggregate(
- requiredChildDistributionExpressions = Some(namedGroupingAttributes),
- groupingExpressions = namedGroupingAttributes,
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
nonCompleteAggregateExpressions = finalAggregateExpressions,
+ nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = Nil,
- resultExpressions = rewrittenResultExpressions,
+ completeAggregateAttributes = Nil,
+ resultExpressions = resultExpressions,
child = partialAggregate)
} else {
- val rewrittenResultExpressions = resultExpressions.map { expr =>
- expr.transformDown {
- case agg: AggregateExpression2 =>
- aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
- case expression =>
- // We do not rely on the equality check at here since attributes may
- // different cosmetically. Instead, we use semanticEquals.
- groupExpressionMap.collectFirst {
- case (expr, ne) if expr semanticEquals expression => ne.toAttribute
- }.getOrElse(expression)
- }.asInstanceOf[NamedExpression]
- }
-
SortBasedAggregate(
- requiredChildDistributionExpressions = Some(namedGroupingAttributes),
- groupingExpressions = namedGroupingAttributes,
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
nonCompleteAggregateExpressions = finalAggregateExpressions,
nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
- initialInputBufferOffset = namedGroupingAttributes.length,
- resultExpressions = rewrittenResultExpressions,
+ initialInputBufferOffset = groupingExpressions.length,
+ resultExpressions = resultExpressions,
child = partialAggregate)
}
@@ -204,10 +144,10 @@ object Utils {
}
def planAggregateWithOneDistinct(
- groupingExpressions: Seq[Expression],
+ groupingExpressions: Seq[NamedExpression],
functionsWithDistinct: Seq[AggregateExpression2],
functionsWithoutDistinct: Seq[AggregateExpression2],
- aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
+ aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
@@ -221,20 +161,7 @@ object Utils {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
// 1. Create an Aggregate Operator for partial aggregations.
- // The grouping expressions are original groupingExpressions and
- // distinct columns. For example, for avg(distinct value) ... group by key
- // the grouping expressions of this Aggregate Operator will be [key, value].
- val namedGroupingExpressions = groupingExpressions.map {
- case ne: NamedExpression => ne -> ne
- // If the expression is not a NamedExpressions, we add an alias.
- // So, when we generate the result of the operator, the Aggregate Operator
- // can directly get the Seq of attributes representing the grouping expressions.
- case other =>
- val withAlias = Alias(other, other.toString)()
- other -> withAlias
- }
- val groupExpressionMap = namedGroupingExpressions.toMap
- val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
// It is safe to call head at here since functionsWithDistinct has at least one
// AggregateExpression2.
@@ -253,22 +180,27 @@ object Utils {
val partialAggregateAttributes =
partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
val partialAggregateGroupingExpressions =
- (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2)
+ groupingExpressions ++ namedDistinctColumnExpressions.map(_._2)
val partialAggregateResult =
- namedGroupingAttributes ++
+ groupingAttributes ++
distinctColumnAttributes ++
partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
val partialAggregate = if (usesTungstenAggregate) {
TungstenAggregate(
- requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+ requiredChildDistributionExpressions = None,
+ // The grouping expressions are original groupingExpressions and
+ // distinct columns. For example, for avg(distinct value) ... group by key
+ // the grouping expressions of this Aggregate Operator will be [key, value].
groupingExpressions = partialAggregateGroupingExpressions,
nonCompleteAggregateExpressions = partialAggregateExpressions,
+ nonCompleteAggregateAttributes = partialAggregateAttributes,
completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
resultExpressions = partialAggregateResult,
child = child)
} else {
SortBasedAggregate(
- requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+ requiredChildDistributionExpressions = None,
groupingExpressions = partialAggregateGroupingExpressions,
nonCompleteAggregateExpressions = partialAggregateExpressions,
nonCompleteAggregateAttributes = partialAggregateAttributes,
@@ -284,41 +216,40 @@ object Utils {
val partialMergeAggregateAttributes =
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
val partialMergeAggregateResult =
- namedGroupingAttributes ++
+ groupingAttributes ++
distinctColumnAttributes ++
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
val partialMergeAggregate = if (usesTungstenAggregate) {
TungstenAggregate(
- requiredChildDistributionExpressions = Some(namedGroupingAttributes),
- groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+ nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
resultExpressions = partialMergeAggregateResult,
child = partialAggregate)
} else {
SortBasedAggregate(
- requiredChildDistributionExpressions = Some(namedGroupingAttributes),
- groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
- initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+ initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = partialMergeAggregateResult,
child = partialAggregate)
}
// 3. Create an Aggregate Operator for partial merge aggregations.
val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
- val finalAggregateAttributes =
- finalAggregateExpressions.map {
- expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
- }
- // Create a map to store those rewritten aggregate functions. We always need to use
- // both function and its corresponding isDistinct flag as the key because function itself
- // does not knows if it is has distinct keyword or now.
- val rewrittenAggregateFunctions =
- mutable.Map.empty[(AggregateFunction2, Boolean), AggregateFunction2]
+ // The attributes of the final aggregation buffer, which is presented as input to the result
+ // projection:
+ val finalAggregateAttributes = finalAggregateExpressions.map {
+ expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
+ }
+
val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
// Children of an AggregateFunction with DISTINCT keyword has already
// been evaluated. At here, we need to replace original children
@@ -328,9 +259,6 @@ object Utils {
case expr if distinctColumnExpressionMap.contains(expr) =>
distinctColumnExpressionMap(expr).toAttribute
}.asInstanceOf[AggregateFunction2]
- // Because we have rewritten the aggregate function, we use rewrittenAggregateFunctions
- // to track the old version and the new version of this function.
- rewrittenAggregateFunctions += (aggregateFunction, true) -> rewrittenAggregateFunction
// We rewrite the aggregate function to a non-distinct aggregation because
// its input will have distinct arguments.
// We just keep the isDistinct setting to true, so when users look at the query plan,
@@ -338,66 +266,30 @@ object Utils {
val rewrittenAggregateExpression =
AggregateExpression2(rewrittenAggregateFunction, Complete, true)
- val aggregateFunctionAttribute =
- aggregateFunctionMap(agg.aggregateFunction, true)._2
- (rewrittenAggregateExpression -> aggregateFunctionAttribute)
+ val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true)
+ (rewrittenAggregateExpression, aggregateFunctionAttribute)
}.unzip
val finalAndCompleteAggregate = if (usesTungstenAggregate) {
- val rewrittenResultExpressions = resultExpressions.map { expr =>
- expr.transform {
- case agg: AggregateExpression2 =>
- val function = agg.aggregateFunction
- val isDistinct = agg.isDistinct
- val aggregateFunction =
- if (rewrittenAggregateFunctions.contains(function, isDistinct)) {
- // If this function has been rewritten, we get the rewritten version from
- // rewrittenAggregateFunctions.
- rewrittenAggregateFunctions(function, isDistinct)
- } else {
- // Oterwise, we get it from aggregateFunctionMap, which contains unique
- // aggregate functions that have not been rewritten.
- aggregateFunctionMap(function, isDistinct)._1
- }
- aggregateFunction.asInstanceOf[DeclarativeAggregate].evaluateExpression
- case expression =>
- // We do not rely on the equality check at here since attributes may
- // different cosmetically. Instead, we use semanticEquals.
- groupExpressionMap.collectFirst {
- case (expr, ne) if expr semanticEquals expression => ne.toAttribute
- }.getOrElse(expression)
- }.asInstanceOf[NamedExpression]
- }
-
TungstenAggregate(
- requiredChildDistributionExpressions = Some(namedGroupingAttributes),
- groupingExpressions = namedGroupingAttributes,
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
nonCompleteAggregateExpressions = finalAggregateExpressions,
+ nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = completeAggregateExpressions,
- resultExpressions = rewrittenResultExpressions,
+ completeAggregateAttributes = completeAggregateAttributes,
+ resultExpressions = resultExpressions,
child = partialMergeAggregate)
} else {
- val rewrittenResultExpressions = resultExpressions.map { expr =>
- expr.transform {
- case agg: AggregateExpression2 =>
- aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
- case expression =>
- // We do not rely on the equality check at here since attributes may
- // different cosmetically. Instead, we use semanticEquals.
- groupExpressionMap.collectFirst {
- case (expr, ne) if expr semanticEquals expression => ne.toAttribute
- }.getOrElse(expression)
- }.asInstanceOf[NamedExpression]
- }
SortBasedAggregate(
- requiredChildDistributionExpressions = Some(namedGroupingAttributes),
- groupingExpressions = namedGroupingAttributes,
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
nonCompleteAggregateExpressions = finalAggregateExpressions,
nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = completeAggregateExpressions,
completeAggregateAttributes = completeAggregateAttributes,
- initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
- resultExpressions = rewrittenResultExpressions,
+ initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
+ resultExpressions = resultExpressions,
child = partialMergeAggregate)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index 7ca677a6c7..ed974b3a53 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -38,7 +38,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte
() => new InterpretedMutableProjection(expr, schema)
}
val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
- iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty,
+ iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty,
Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
val numPages = iter.getHashMap.getNumDataPages
assert(numPages === 1)