aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2015-10-14 17:27:27 -0700
committerYin Huai <yhuai@databricks.com>2015-10-14 17:27:50 -0700
commit4ace4f8a9c91beb21a0077e12b75637a4560a542 (patch)
tree3d9c2224cfec1cc0d839114e72b52856d68b8356 /sql/core
parent1baaf2b9bd7c949a8f95cd14fc1be2a56e1139b3 (diff)
downloadspark-4ace4f8a9c91beb21a0077e12b75637a4560a542.tar.gz
spark-4ace4f8a9c91beb21a0077e12b75637a4560a542.tar.bz2
spark-4ace4f8a9c91beb21a0077e12b75637a4560a542.zip
[SPARK-11017] [SQL] Support ImperativeAggregates in TungstenAggregate
This patch extends TungstenAggregate to support ImperativeAggregate functions. The existing TungstenAggregate operator only supported DeclarativeAggregate functions, which are defined in terms of Catalyst expressions and can be evaluated via generated projections. ImperativeAggregate functions, on the other hand, are evaluated by calling their `initialize`, `update`, `merge`, and `eval` methods. The basic strategy here is similar to how SortBasedAggregate evaluates both types of aggregate functions: use a generated projection to evaluate the expression-based declarative aggregates with dummy placeholder expressions inserted in place of the imperative aggregate function output, then invoke the imperative aggregate functions and target them against the aggregation buffer. The bulk of the diff here consists of code that was copied and adapted from SortBasedAggregate, with some key changes to handle TungstenAggregate's sort fallback path. Author: Josh Rosen <joshrosen@databricks.com> Closes #9038 from JoshRosen/support-interpreted-in-tungsten-agg-final.
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala29
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala22
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala250
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala79
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala269
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala2
6 files changed, 407 insertions, 244 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 8e0fbd109b..99fb7a40b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -83,7 +83,7 @@ abstract class AggregationIterator(
var i = 0
while (i < allAggregateExpressions.length) {
val func = allAggregateExpressions(i).aggregateFunction
- val funcWithBoundReferences = allAggregateExpressions(i).mode match {
+ val funcWithBoundReferences: AggregateFunction2 = allAggregateExpressions(i).mode match {
case Partial | Complete if func.isInstanceOf[ImperativeAggregate] =>
// We need to create BoundReferences if the function is not an
// expression-based aggregate function (it does not support code-gen) and the mode of
@@ -94,24 +94,24 @@ abstract class AggregationIterator(
case _ =>
// We only need to set inputBufferOffset for aggregate functions with mode
// PartialMerge and Final.
- func match {
+ val updatedFunc = func match {
case function: ImperativeAggregate =>
function.withNewInputAggBufferOffset(inputBufferOffset)
- case _ =>
+ case function => function
}
inputBufferOffset += func.aggBufferSchema.length
- func
+ updatedFunc
}
- // Set mutableBufferOffset for this function. It is important that setting
- // mutableBufferOffset happens after all potential bindReference operations
- // because bindReference will create a new instance of the function.
- funcWithBoundReferences match {
+ val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
case function: ImperativeAggregate =>
+ // Set mutableBufferOffset for this function. It is important that setting
+ // mutableBufferOffset happens after all potential bindReference operations
+ // because bindReference will create a new instance of the function.
function.withNewMutableAggBufferOffset(mutableBufferOffset)
- case _ =>
+ case function => function
}
- mutableBufferOffset += funcWithBoundReferences.aggBufferSchema.length
- functions(i) = funcWithBoundReferences
+ mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length
+ functions(i) = funcWithUpdatedAggBufferOffset
i += 1
}
functions
@@ -320,7 +320,7 @@ abstract class AggregationIterator(
// Initializing the function used to generate the output row.
protected val generateOutput: (InternalRow, MutableRow) => InternalRow = {
val rowToBeEvaluated = new JoinedRow
- val safeOutputRow = new GenericMutableRow(resultExpressions.length)
+ val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType))
val mutableOutput = if (outputsUnsafeRows) {
UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow)
} else {
@@ -358,7 +358,8 @@ abstract class AggregationIterator(
val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)()
val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
// TODO: Use unsafe row.
- val aggregateResult = new GenericMutableRow(aggregateResultSchema.length)
+ val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType))
+ expressionAggEvalProjection.target(aggregateResult)
val resultProjection =
newMutableProjection(
resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)()
@@ -366,7 +367,7 @@ abstract class AggregationIterator(
(currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
// Generate results for all expression-based aggregate functions.
- expressionAggEvalProjection.target(aggregateResult)(currentBuffer)
+ expressionAggEvalProjection(currentBuffer)
// Generate results for all imperative aggregate functions.
var i = 0
while (i < allImperativeAggregateFunctions.length) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 7b3d072b2e..c342940e6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.StructType
case class TungstenAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -34,10 +35,18 @@ case class TungstenAggregate(
nonCompleteAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2],
completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryNode {
+ private[this] val aggregateBufferAttributes = {
+ (nonCompleteAggregateExpressions ++ completeAggregateExpressions)
+ .flatMap(_.aggregateFunction.aggBufferAttributes)
+ }
+
+ require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes))
+
override private[sql] lazy val metrics = Map(
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -82,6 +91,7 @@ case class TungstenAggregate(
nonCompleteAggregateAttributes,
completeAggregateExpressions,
completeAggregateAttributes,
+ initialInputBufferOffset,
resultExpressions,
newMutableProjection,
child.output,
@@ -138,3 +148,13 @@ case class TungstenAggregate(
}
}
}
+
+object TungstenAggregate {
+ def supportsAggregate(
+ groupingExpressions: Seq[Expression],
+ aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+ val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
+ UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
+ UnsafeProjection.canSupport(groupingExpressions)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 4bb95c9eb7..fe708a5f71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.aggregate
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.unsafe.KVIterator
import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext}
import org.apache.spark.sql.catalyst.expressions._
@@ -79,6 +81,7 @@ class TungstenAggregationIterator(
nonCompleteAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2],
completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
@@ -134,19 +137,74 @@ class TungstenAggregationIterator(
completeAggregateExpressions.map(_.mode).distinct.headOption
}
- // All aggregate functions. TungstenAggregationIterator only handles expression-based aggregate.
- // If there is any functions that is an ImperativeAggregateFunction, we throw an
- // IllegalStateException.
- private[this] val allAggregateFunctions: Array[DeclarativeAggregate] = {
- if (!allAggregateExpressions.forall(
- _.aggregateFunction.isInstanceOf[DeclarativeAggregate])) {
- throw new IllegalStateException(
- "Only ExpressionAggregateFunctions should be passed in TungstenAggregationIterator.")
+ // Initialize all AggregateFunctions by binding references, if necessary,
+ // and setting inputBufferOffset and mutableBufferOffset.
+ private def initializeAllAggregateFunctions(
+ startingInputBufferOffset: Int): Array[AggregateFunction2] = {
+ var mutableBufferOffset = 0
+ var inputBufferOffset: Int = startingInputBufferOffset
+ val functions = new Array[AggregateFunction2](allAggregateExpressions.length)
+ var i = 0
+ while (i < allAggregateExpressions.length) {
+ val func = allAggregateExpressions(i).aggregateFunction
+ val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length
+ // We need to use this mode instead of func.mode in order to handle aggregation mode switching
+ // when switching to sort-based aggregation:
+ val mode = if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2
+ val funcWithBoundReferences = mode match {
+ case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] =>
+ // We need to create BoundReferences if the function is not an
+ // expression-based aggregate function (it does not support code-gen) and the mode of
+ // this function is Partial or Complete because we will call eval of this
+ // function's children in the update method of this aggregate function.
+ // Those eval calls require BoundReferences to work.
+ BindReferences.bindReference(func, originalInputAttributes)
+ case _ =>
+ // We only need to set inputBufferOffset for aggregate functions with mode
+ // PartialMerge and Final.
+ val updatedFunc = func match {
+ case function: ImperativeAggregate =>
+ function.withNewInputAggBufferOffset(inputBufferOffset)
+ case function => function
+ }
+ inputBufferOffset += func.aggBufferSchema.length
+ updatedFunc
+ }
+ val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
+ case function: ImperativeAggregate =>
+ // Set mutableBufferOffset for this function. It is important that setting
+ // mutableBufferOffset happens after all potential bindReference operations
+ // because bindReference will create a new instance of the function.
+ function.withNewMutableAggBufferOffset(mutableBufferOffset)
+ case function => function
+ }
+ mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length
+ functions(i) = funcWithUpdatedAggBufferOffset
+ i += 1
}
+ functions
+ }
- allAggregateExpressions
- .map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
- .toArray
+ private[this] var allAggregateFunctions: Array[AggregateFunction2] =
+ initializeAllAggregateFunctions(initialInputBufferOffset)
+
+ // Positions of those imperative aggregate functions in allAggregateFunctions.
+ // For example, say that we have func1, func2, func3, func4 in aggregateFunctions, and
+ // func2 and func3 are imperative aggregate functions. Then
+ // allImperativeAggregateFunctionPositions will be [1, 2]. Note that this does not need to be
+ // updated when falling back to sort-based aggregation because the positions of the aggregate
+ // functions do not change in that case.
+ private[this] val allImperativeAggregateFunctionPositions: Array[Int] = {
+ val positions = new ArrayBuffer[Int]()
+ var i = 0
+ while (i < allAggregateFunctions.length) {
+ allAggregateFunctions(i) match {
+ case agg: DeclarativeAggregate =>
+ case _ => positions += i
+ }
+ i += 1
+ }
+ positions.toArray
}
///////////////////////////////////////////////////////////////////////////
@@ -155,25 +213,31 @@ class TungstenAggregationIterator(
// rows.
///////////////////////////////////////////////////////////////////////////
- // The projection used to initialize buffer values.
- private[this] val initialProjection: MutableProjection = {
- val initExpressions = allAggregateFunctions.flatMap(_.initialValues)
+ // The projection used to initialize buffer values for all expression-based aggregates.
+ // Note that this projection does not need to be updated when switching to sort-based aggregation
+ // because the schema of empty aggregation buffers does not change in that case.
+ private[this] val expressionAggInitialProjection: MutableProjection = {
+ val initExpressions = allAggregateFunctions.flatMap {
+ case ae: DeclarativeAggregate => ae.initialValues
+ // For the positions corresponding to imperative aggregate functions, we'll use special
+ // no-op expressions which are ignored during projection code-generation.
+ case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp)
+ }
newMutableProjection(initExpressions, Nil)()
}
// Creates a new aggregation buffer and initializes buffer values.
- // This functions should be only called at most three times (when we create the hash map,
+ // This function should be only called at most three times (when we create the hash map,
// when we switch to sort-based aggregation, and when we create the re-used buffer for
// sort-based aggregation).
private def createNewAggregationBuffer(): UnsafeRow = {
val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
- val bufferRowSize: Int = bufferSchema.length
-
- val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
- val unsafeProjection =
- UnsafeProjection.create(bufferSchema.map(_.dataType))
- val buffer = unsafeProjection.apply(genericMutableBuffer)
- initialProjection.target(buffer)(EmptyRow)
+ val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType))
+ .apply(new GenericMutableRow(bufferSchema.length))
+ // Initialize declarative aggregates' buffer values
+ expressionAggInitialProjection.target(buffer)(EmptyRow)
+ // Initialize imperative aggregates' buffer values
+ allAggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
buffer
}
@@ -187,72 +251,124 @@ class TungstenAggregationIterator(
aggregationMode match {
// Partial-only
case (Some(Partial), None) =>
- val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions)
- val updateProjection =
+ val updateExpressions = allAggregateFunctions.flatMap {
+ case ae: DeclarativeAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ }
+ val imperativeAggregateFunctions: Array[ImperativeAggregate] =
+ allAggregateFunctions.collect { case func: ImperativeAggregate => func}
+ val expressionAggUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
(currentBuffer: UnsafeRow, row: InternalRow) => {
- updateProjection.target(currentBuffer)
- updateProjection(joinedRow(currentBuffer, row))
+ expressionAggUpdateProjection.target(currentBuffer)
+ // Process all expression-based aggregate functions.
+ expressionAggUpdateProjection(joinedRow(currentBuffer, row))
+ // Process all imperative aggregate functions
+ var i = 0
+ while (i < imperativeAggregateFunctions.length) {
+ imperativeAggregateFunctions(i).update(currentBuffer, row)
+ i += 1
+ }
}
// PartialMerge-only or Final-only
case (Some(PartialMerge), None) | (Some(Final), None) =>
- val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions)
- val mergeProjection =
+ val mergeExpressions = allAggregateFunctions.flatMap {
+ case ae: DeclarativeAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ }
+ val imperativeAggregateFunctions: Array[ImperativeAggregate] =
+ allAggregateFunctions.collect { case func: ImperativeAggregate => func}
+ // This projection is used to merge buffer values for all expression-based aggregates.
+ val expressionAggMergeProjection =
newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)()
(currentBuffer: UnsafeRow, row: InternalRow) => {
- mergeProjection.target(currentBuffer)
- mergeProjection(joinedRow(currentBuffer, row))
+ // Process all expression-based aggregate functions.
+ expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, row))
+ // Process all imperative aggregate functions.
+ var i = 0
+ while (i < imperativeAggregateFunctions.length) {
+ imperativeAggregateFunctions(i).merge(currentBuffer, row)
+ i += 1
+ }
}
// Final-Complete
case (Some(Final), Some(Complete)) =>
- val nonCompleteAggregateFunctions: Array[DeclarativeAggregate] =
- allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
- val completeAggregateFunctions: Array[DeclarativeAggregate] =
+ val completeAggregateFunctions: Array[AggregateFunction2] =
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+ val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
+ completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
+ val nonCompleteAggregateFunctions: Array[AggregateFunction2] =
+ allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
+ val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] =
+ nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func }
val completeOffsetExpressions =
Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
val mergeExpressions =
- nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions
+ nonCompleteAggregateFunctions.flatMap {
+ case ae: DeclarativeAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ } ++ completeOffsetExpressions
val finalMergeProjection =
newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)()
// We do not touch buffer values of aggregate functions with the Final mode.
val finalOffsetExpressions =
Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
- val updateExpressions =
- finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions)
+ val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
+ case ae: DeclarativeAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ }
val completeUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
(currentBuffer: UnsafeRow, row: InternalRow) => {
val input = joinedRow(currentBuffer, row)
- // For all aggregate functions with mode Complete, update the given currentBuffer.
+ // For all aggregate functions with mode Complete, update buffers.
completeUpdateProjection.target(currentBuffer)(input)
+ var i = 0
+ while (i < completeImperativeAggregateFunctions.length) {
+ completeImperativeAggregateFunctions(i).update(currentBuffer, row)
+ i += 1
+ }
// For all aggregate functions with mode Final, merge buffer values in row to
// currentBuffer.
finalMergeProjection.target(currentBuffer)(input)
+ i = 0
+ while (i < nonCompleteImperativeAggregateFunctions.length) {
+ nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row)
+ i += 1
+ }
}
// Complete-only
case (None, Some(Complete)) =>
- val completeAggregateFunctions: Array[DeclarativeAggregate] =
+ val completeAggregateFunctions: Array[AggregateFunction2] =
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+ // All imperative aggregate functions with mode Complete.
+ val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
+ completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
- val updateExpressions =
- completeAggregateFunctions.flatMap(_.updateExpressions)
- val completeUpdateProjection =
+ val updateExpressions = completeAggregateFunctions.flatMap {
+ case ae: DeclarativeAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ }
+ val completeExpressionAggUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
(currentBuffer: UnsafeRow, row: InternalRow) => {
- completeUpdateProjection.target(currentBuffer)
- // For all aggregate functions with mode Complete, update the given currentBuffer.
- completeUpdateProjection(joinedRow(currentBuffer, row))
+ // For all aggregate functions with mode Complete, update buffers.
+ completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer, row))
+ var i = 0
+ while (i < completeImperativeAggregateFunctions.length) {
+ completeImperativeAggregateFunctions(i).update(currentBuffer, row)
+ i += 1
+ }
}
// Grouping only.
@@ -288,17 +404,30 @@ class TungstenAggregationIterator(
val joinedRow = new JoinedRow()
val evalExpressions = allAggregateFunctions.map {
case ae: DeclarativeAggregate => ae.evaluateExpression
- // case agg: AggregateFunction2 => Literal.create(null, agg.dataType)
+ case agg: AggregateFunction2 => NoOp
}
- val expressionAggEvalProjection = UnsafeProjection.create(evalExpressions, bufferAttributes)
+ val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)()
// These are the attributes of the row produced by `expressionAggEvalProjection`
val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
+ val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType))
+ expressionAggEvalProjection.target(aggregateResult)
val resultProjection =
UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema)
+ val allImperativeAggregateFunctions: Array[ImperativeAggregate] =
+ allAggregateFunctions.collect { case func: ImperativeAggregate => func}
+
(currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
// Generate results for all expression-based aggregate functions.
- val aggregateResult = expressionAggEvalProjection.apply(currentBuffer)
+ expressionAggEvalProjection(currentBuffer)
+ // Generate results for all imperative aggregate functions.
+ var i = 0
+ while (i < allImperativeAggregateFunctions.length) {
+ aggregateResult.update(
+ allImperativeAggregateFunctionPositions(i),
+ allImperativeAggregateFunctions(i).eval(currentBuffer))
+ i += 1
+ }
resultProjection(joinedRow(currentGroupingKey, aggregateResult))
}
@@ -481,10 +610,27 @@ class TungstenAggregationIterator(
// When needsProcess is false, the format of input rows is groupingKey + aggregation buffer.
// We need to project the aggregation buffer part from an input row.
val buffer = createNewAggregationBuffer()
- // The originalInputAttributes are using cloneBufferAttributes. So, we need to use
- // allAggregateFunctions.flatMap(_.cloneBufferAttributes).
+ // In principle, we could use `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` to
+ // extract the aggregation buffer. In practice, however, we extract it positionally by relying
+ // on it being present at the end of the row. The reason for this relates to how the different
+ // aggregates handle input binding.
+ //
+ // ImperativeAggregate uses field numbers and field number offsets to manipulate its buffers,
+ // so its correctness does not rely on attribute bindings. When we fall back to sort-based
+ // aggregation, these field number offsets (mutableAggBufferOffset and inputAggBufferOffset)
+ // need to be updated and any internal state in the aggregate functions themselves must be
+ // reset, so we call withNewMutableAggBufferOffset and withNewInputAggBufferOffset to reset
+ // this state and update the offsets.
+ //
+ // The updated ImperativeAggregate will have different attribute ids for its
+ // aggBufferAttributes and inputAggBufferAttributes. This isn't a problem for the actual
+ // ImperativeAggregate evaluation, but it means that
+ // `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` will no longer match the
+ // attributes in `originalInputAttributes`, which is why we can't use those attributes here.
+ //
+ // For more details, see the discussion on PR #9038.
val bufferExtractor = newMutableProjection(
- allAggregateFunctions.flatMap(_.inputAggBufferAttributes),
+ originalInputAttributes.drop(initialInputBufferOffset),
originalInputAttributes)()
bufferExtractor.target(buffer)
@@ -511,8 +657,10 @@ class TungstenAggregationIterator(
}
aggregationMode = newAggregationMode
+ allAggregateFunctions = initializeAllAggregateFunctions(startingInputBufferOffset = 0)
+
// Basically the value of the KVIterator returned by externalSorter
- // will just aggregation buffer. At here, we use cloneBufferAttributes.
+ // will just aggregation buffer. At here, we use inputAggBufferAttributes.
val newInputAttributes: Seq[Attribute] =
allAggregateFunctions.flatMap(_.inputAggBufferAttributes)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index fd02be1225..d2f56e0fc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -321,9 +321,17 @@ private[sql] class InputAggregationBuffer private[sql] (
*/
private[sql] case class ScalaUDAF(
children: Seq[Expression],
- udaf: UserDefinedAggregateFunction)
+ udaf: UserDefinedAggregateFunction,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
extends ImperativeAggregate with Logging {
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
require(
children.length == udaf.inputSchema.length,
s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
@@ -341,6 +349,11 @@ private[sql] case class ScalaUDAF(
override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
+ // Note: although this simply copies aggBufferAttributes, this common code can not be placed
+ // in the superclass because that will lead to initialization ordering issues.
+ override val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
+
private[this] lazy val childrenSchema: StructType = {
val inputFields = children.zipWithIndex.map {
case (child, index) =>
@@ -382,51 +395,33 @@ private[sql] case class ScalaUDAF(
}
// This buffer is only used at executor side.
- private[this] var inputAggregateBuffer: InputAggregationBuffer = null
-
- // This buffer is only used at executor side.
- private[this] var mutableAggregateBuffer: MutableAggregationBufferImpl = null
+ private[this] lazy val inputAggregateBuffer: InputAggregationBuffer = {
+ new InputAggregationBuffer(
+ aggBufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ inputAggBufferOffset,
+ null)
+ }
// This buffer is only used at executor side.
- private[this] var evalAggregateBuffer: InputAggregationBuffer = null
-
- /**
- * Sets the inputBufferOffset to newInputBufferOffset and then create a new instance of
- * `inputAggregateBuffer` based on this new inputBufferOffset.
- */
- override def withNewInputAggBufferOffset(newInputBufferOffset: Int): Unit = {
- super.withNewInputAggBufferOffset(newInputBufferOffset)
- // inputBufferOffset has been updated.
- inputAggregateBuffer =
- new InputAggregationBuffer(
- aggBufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- inputAggBufferOffset,
- null)
+ private[this] lazy val mutableAggregateBuffer: MutableAggregationBufferImpl = {
+ new MutableAggregationBufferImpl(
+ aggBufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ mutableAggBufferOffset,
+ null)
}
- /**
- * Sets the mutableBufferOffset to newMutableBufferOffset and then create a new instance of
- * `mutableAggregateBuffer` and `evalAggregateBuffer` based on this new mutableBufferOffset.
- */
- override def withNewMutableAggBufferOffset(newMutableBufferOffset: Int): Unit = {
- super.withNewMutableAggBufferOffset(newMutableBufferOffset)
- // mutableBufferOffset has been updated.
- mutableAggregateBuffer =
- new MutableAggregationBufferImpl(
- aggBufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- mutableAggBufferOffset,
- null)
- evalAggregateBuffer =
- new InputAggregationBuffer(
- aggBufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- mutableAggBufferOffset,
- null)
+ // This buffer is only used at executor side.
+ private[this] lazy val evalAggregateBuffer: InputAggregationBuffer = {
+ new InputAggregationBuffer(
+ aggBufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ mutableAggBufferOffset,
+ null)
}
override def initialize(buffer: MutableRow): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index cf6e7ed0d3..eaafd83158 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -19,21 +19,12 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.execution.SparkPlan
/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
*/
object Utils {
- def supportsTungstenAggregate(
- groupingExpressions: Seq[Expression],
- aggregateBufferAttributes: Seq[Attribute]): Boolean = {
- val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
-
- UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
- UnsafeProjection.canSupport(groupingExpressions)
- }
def planAggregateWithoutPartial(
groupingExpressions: Seq[NamedExpression],
@@ -70,8 +61,7 @@ object Utils {
// Check if we can use TungstenAggregate.
val usesTungstenAggregate =
child.sqlContext.conf.unsafeEnabled &&
- aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[DeclarativeAggregate]) &&
- supportsTungstenAggregate(
+ TungstenAggregate.supportsAggregate(
groupingExpressions,
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
@@ -94,6 +84,7 @@ object Utils {
nonCompleteAggregateAttributes = partialAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
+ initialInputBufferOffset = 0,
resultExpressions = partialResultExpressions,
child = child)
} else {
@@ -125,6 +116,7 @@ object Utils {
nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
+ initialInputBufferOffset = groupingExpressions.length,
resultExpressions = resultExpressions,
child = partialAggregate)
} else {
@@ -154,143 +146,150 @@ object Utils {
val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct
val usesTungstenAggregate =
child.sqlContext.conf.unsafeEnabled &&
- aggregateExpressions.forall(
- _.aggregateFunction.isInstanceOf[DeclarativeAggregate]) &&
- supportsTungstenAggregate(
+ TungstenAggregate.supportsAggregate(
groupingExpressions,
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
- // 1. Create an Aggregate Operator for partial aggregations.
- val groupingAttributes = groupingExpressions.map(_.toAttribute)
-
- // It is safe to call head at here since functionsWithDistinct has at least one
- // AggregateExpression2.
- val distinctColumnExpressions =
- functionsWithDistinct.head.aggregateFunction.children
- val namedDistinctColumnExpressions = distinctColumnExpressions.map {
- case ne: NamedExpression => ne -> ne
- case other =>
- val withAlias = Alias(other, other.toString)()
- other -> withAlias
+ // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
+ // DISTINCT aggregate function, all of those functions will have the same column expression.
+ // For example, it would be valid for functionsWithDistinct to be
+ // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is
+ // disallowed because those two distinct aggregates have different column expressions.
+ val distinctColumnExpression: Expression = {
+ val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
+ assert(allDistinctColumnExpressions.length == 1)
+ allDistinctColumnExpressions.head
+ }
+ val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match {
+ case ne: NamedExpression => ne
+ case other => Alias(other, other.toString)()
}
- val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
- val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
+ val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
- val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
- val partialAggregateAttributes =
- partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
- val partialAggregateGroupingExpressions =
- groupingExpressions ++ namedDistinctColumnExpressions.map(_._2)
- val partialAggregateResult =
+ // 1. Create an Aggregate Operator for partial aggregations.
+ val partialAggregate: SparkPlan = {
+ val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+ val partialAggregateAttributes =
+ partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+ // We will group by the original grouping expression, plus an additional expression for the
+ // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
+ // expressions will be [key, value].
+ val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression
+ val partialAggregateResult =
groupingAttributes ++
- distinctColumnAttributes ++
- partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
- val partialAggregate = if (usesTungstenAggregate) {
- TungstenAggregate(
- requiredChildDistributionExpressions = None,
- // The grouping expressions are original groupingExpressions and
- // distinct columns. For example, for avg(distinct value) ... group by key
- // the grouping expressions of this Aggregate Operator will be [key, value].
- groupingExpressions = partialAggregateGroupingExpressions,
- nonCompleteAggregateExpressions = partialAggregateExpressions,
- nonCompleteAggregateAttributes = partialAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- resultExpressions = partialAggregateResult,
- child = child)
- } else {
- SortBasedAggregate(
- requiredChildDistributionExpressions = None,
- groupingExpressions = partialAggregateGroupingExpressions,
- nonCompleteAggregateExpressions = partialAggregateExpressions,
- nonCompleteAggregateAttributes = partialAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- initialInputBufferOffset = 0,
- resultExpressions = partialAggregateResult,
- child = child)
+ Seq(distinctColumnAttribute) ++
+ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+ if (usesTungstenAggregate) {
+ TungstenAggregate(
+ requiredChildDistributionExpressions = None,
+ groupingExpressions = partialAggregateGroupingExpressions,
+ nonCompleteAggregateExpressions = partialAggregateExpressions,
+ nonCompleteAggregateAttributes = partialAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = 0,
+ resultExpressions = partialAggregateResult,
+ child = child)
+ } else {
+ SortBasedAggregate(
+ requiredChildDistributionExpressions = None,
+ groupingExpressions = partialAggregateGroupingExpressions,
+ nonCompleteAggregateExpressions = partialAggregateExpressions,
+ nonCompleteAggregateAttributes = partialAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = 0,
+ resultExpressions = partialAggregateResult,
+ child = child)
+ }
}
// 2. Create an Aggregate Operator for partial merge aggregations.
- val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
- val partialMergeAggregateAttributes =
- partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
- val partialMergeAggregateResult =
+ val partialMergeAggregate: SparkPlan = {
+ val partialMergeAggregateExpressions =
+ functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+ val partialMergeAggregateAttributes =
+ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+ val partialMergeAggregateResult =
groupingAttributes ++
- distinctColumnAttributes ++
- partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
- val partialMergeAggregate = if (usesTungstenAggregate) {
- TungstenAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
- nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
- nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- resultExpressions = partialMergeAggregateResult,
- child = partialAggregate)
- } else {
- SortBasedAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
- nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
- nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
- resultExpressions = partialMergeAggregateResult,
- child = partialAggregate)
+ Seq(distinctColumnAttribute) ++
+ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+ if (usesTungstenAggregate) {
+ TungstenAggregate(
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+ nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+ nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ resultExpressions = partialMergeAggregateResult,
+ child = partialAggregate)
+ } else {
+ SortBasedAggregate(
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+ nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+ nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ resultExpressions = partialMergeAggregateResult,
+ child = partialAggregate)
+ }
}
- // 3. Create an Aggregate Operator for partial merge aggregations.
- val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
- // The attributes of the final aggregation buffer, which is presented as input to the result
- // projection:
- val finalAggregateAttributes = finalAggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
+ // 3. Create an Aggregate Operator for the final aggregation.
+ val finalAndCompleteAggregate: SparkPlan = {
+ val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
+ // The attributes of the final aggregation buffer, which is presented as input to the result
+ // projection:
+ val finalAggregateAttributes = finalAggregateExpressions.map {
+ expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
+ }
- val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
- // Children of an AggregateFunction with DISTINCT keyword has already
- // been evaluated. At here, we need to replace original children
- // to AttributeReferences.
- case agg @ AggregateExpression2(aggregateFunction, mode, true) =>
- val rewrittenAggregateFunction = aggregateFunction.transformDown {
- case expr if distinctColumnExpressionMap.contains(expr) =>
- distinctColumnExpressionMap(expr).toAttribute
- }.asInstanceOf[AggregateFunction2]
- // We rewrite the aggregate function to a non-distinct aggregation because
- // its input will have distinct arguments.
- // We just keep the isDistinct setting to true, so when users look at the query plan,
- // they still can see distinct aggregations.
- val rewrittenAggregateExpression =
- AggregateExpression2(rewrittenAggregateFunction, Complete, true)
+ val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
+ // Children of an AggregateFunction with DISTINCT keyword has already
+ // been evaluated. At here, we need to replace original children
+ // to AttributeReferences.
+ case agg @ AggregateExpression2(aggregateFunction, mode, true) =>
+ val rewrittenAggregateFunction = aggregateFunction.transformDown {
+ case expr if expr == distinctColumnExpression => distinctColumnAttribute
+ }.asInstanceOf[AggregateFunction2]
+ // We rewrite the aggregate function to a non-distinct aggregation because
+ // its input will have distinct arguments.
+ // We just keep the isDistinct setting to true, so when users look at the query plan,
+ // they still can see distinct aggregations.
+ val rewrittenAggregateExpression =
+ AggregateExpression2(rewrittenAggregateFunction, Complete, isDistinct = true)
- val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true)
- (rewrittenAggregateExpression, aggregateFunctionAttribute)
- }.unzip
-
- val finalAndCompleteAggregate = if (usesTungstenAggregate) {
- TungstenAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes,
- nonCompleteAggregateExpressions = finalAggregateExpressions,
- nonCompleteAggregateAttributes = finalAggregateAttributes,
- completeAggregateExpressions = completeAggregateExpressions,
- completeAggregateAttributes = completeAggregateAttributes,
- resultExpressions = resultExpressions,
- child = partialMergeAggregate)
- } else {
- SortBasedAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes,
- nonCompleteAggregateExpressions = finalAggregateExpressions,
- nonCompleteAggregateAttributes = finalAggregateAttributes,
- completeAggregateExpressions = completeAggregateExpressions,
- completeAggregateAttributes = completeAggregateAttributes,
- initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
- resultExpressions = resultExpressions,
- child = partialMergeAggregate)
+ val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true)
+ (rewrittenAggregateExpression, aggregateFunctionAttribute)
+ }.unzip
+ if (usesTungstenAggregate) {
+ TungstenAggregate(
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
+ nonCompleteAggregateExpressions = finalAggregateExpressions,
+ nonCompleteAggregateAttributes = finalAggregateAttributes,
+ completeAggregateExpressions = completeAggregateExpressions,
+ completeAggregateAttributes = completeAggregateAttributes,
+ initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ resultExpressions = resultExpressions,
+ child = partialMergeAggregate)
+ } else {
+ SortBasedAggregate(
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
+ nonCompleteAggregateExpressions = finalAggregateExpressions,
+ nonCompleteAggregateAttributes = finalAggregateAttributes,
+ completeAggregateExpressions = completeAggregateExpressions,
+ completeAggregateAttributes = completeAggregateAttributes,
+ initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ resultExpressions = resultExpressions,
+ child = partialMergeAggregate)
+ }
}
finalAndCompleteAggregate :: Nil
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index ed974b3a53..0cc4988ff6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -39,7 +39,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte
}
val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty,
- Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
+ 0, Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
val numPages = iter.getHashMap.getNumDataPages
assert(numPages === 1)
} finally {