9 files changed, 457 insertions, 260 deletions
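
The recurring change in this diff is that ImperativeAggregate's buffer offsets become immutable constructor parameters, and withNewMutableAggBufferOffset / withNewInputAggBufferOffset return repositioned copies instead of mutating the function in place (see HyperLogLogPlusPlus, ScalaUDAF, and HiveUDAFFunction below). The following is a minimal, self-contained sketch of that pattern in plain Scala; the Row/MutableRow traits and the SimpleCount aggregate are simplified stand-ins invented for illustration, not the real catalyst types.

object CopyBasedOffsetSketch {

  // Simplified stand-ins for catalyst's InternalRow / MutableRow.
  trait Row { def getLong(i: Int): Long }
  trait MutableRow extends Row { def setLong(i: Int, v: Long): Unit }

  // Simplified stand-in for ImperativeAggregate: offsets are read-only members and the
  // withNew*Offset methods return repositioned copies.
  trait SimpleImperativeAggregate {
    def mutableAggBufferOffset: Int
    def inputAggBufferOffset: Int
    def withNewMutableAggBufferOffset(newOffset: Int): SimpleImperativeAggregate
    def withNewInputAggBufferOffset(newOffset: Int): SimpleImperativeAggregate
    def initialize(buffer: MutableRow): Unit
    def update(buffer: MutableRow, input: Row): Unit
    def merge(buffer: MutableRow, inputBuffer: Row): Unit
  }

  // A one-slot COUNT-like aggregate that touches its buffers only through the offsets,
  // mirroring how the diff wires the offsets into case-class parameters and copy().
  case class SimpleCount(
      mutableAggBufferOffset: Int = 0,
      inputAggBufferOffset: Int = 0) extends SimpleImperativeAggregate {

    override def withNewMutableAggBufferOffset(newOffset: Int): SimpleImperativeAggregate =
      copy(mutableAggBufferOffset = newOffset)

    override def withNewInputAggBufferOffset(newOffset: Int): SimpleImperativeAggregate =
      copy(inputAggBufferOffset = newOffset)

    override def initialize(buffer: MutableRow): Unit =
      buffer.setLong(mutableAggBufferOffset, 0L)

    override def update(buffer: MutableRow, input: Row): Unit =
      buffer.setLong(mutableAggBufferOffset, buffer.getLong(mutableAggBufferOffset) + 1L)

    override def merge(buffer: MutableRow, inputBuffer: Row): Unit =
      buffer.setLong(
        mutableAggBufferOffset,
        buffer.getLong(mutableAggBufferOffset) + inputBuffer.getLong(inputAggBufferOffset))
  }
}
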
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 8aad0b7dee..c0bc7ec09c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -472,10 +472,20 @@ case class Sum(child: Expression) extends DeclarativeAggregate { * @param relativeSD the maximum estimation error allowed. */ // scalastyle:on -case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) - extends ImperativeAggregate { +case class HyperLogLogPlusPlus( + child: Expression, + relativeSD: Double = 0.05, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends ImperativeAggregate { import HyperLogLogPlusPlus._ + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + /** * HLL++ uses 'p' bits for addressing. The more addressing bits we use, the more precise the * algorithm will be, and the more memory it will require. The 'p' value is based on the relative @@ -546,6 +556,11 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) AttributeReference(s"MS[$i]", LongType)() } + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + /** Fill all words with zeros. */ override def initialize(buffer: MutableRow): Unit = { var word = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 9ba3a9c980..a2fab258fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -150,6 +150,10 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp * We need to perform similar field number arithmetic when merging multiple intermediate * aggregate buffers together in `merge()` (in this case, use `inputAggBufferOffset` when accessing * the input buffer). + * + * Correct ImperativeAggregate evaluation depends on the correctness of `mutableAggBufferOffset` and + * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes` + * and `inputAggBufferAttributes`. */ abstract class ImperativeAggregate extends AggregateFunction2 { @@ -172,11 +176,13 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * avg(y) mutableAggBufferOffset = 2 * */ - protected var mutableAggBufferOffset: Int = 0 + protected val mutableAggBufferOffset: Int - def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Unit = { - mutableAggBufferOffset = newMutableAggBufferOffset - } + /** + * Returns a copy of this ImperativeAggregate with an updated mutableAggBufferOffset. 
+ * This new copy's attributes may have different ids than the original. + */ + def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate /** * The offset of this function's start buffer value in the underlying shared input aggregation @@ -203,11 +209,17 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * avg(y) inputAggBufferOffset = 3 * */ - protected var inputAggBufferOffset: Int = 0 + protected val inputAggBufferOffset: Int - def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Unit = { - inputAggBufferOffset = newInputAggBufferOffset - } + /** + * Returns a copy of this ImperativeAggregate with an updated mutableAggBufferOffset. + * This new copy's attributes may have different ids than the original. + */ + def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate + + // Note: although all subclasses implement inputAggBufferAttributes by simply cloning + // aggBufferAttributes, that common clone code cannot be placed here in the abstract + // ImperativeAggregate class, since that will lead to initialization ordering issues. /** * Initializes the mutable aggregation buffer located in `mutableAggBuffer`. @@ -231,9 +243,6 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. */ def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit - - final lazy val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 8e0fbd109b..99fb7a40b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -83,7 +83,7 @@ abstract class AggregationIterator( var i = 0 while (i < allAggregateExpressions.length) { val func = allAggregateExpressions(i).aggregateFunction - val funcWithBoundReferences = allAggregateExpressions(i).mode match { + val funcWithBoundReferences: AggregateFunction2 = allAggregateExpressions(i).mode match { case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => // We need to create BoundReferences if the function is not an // expression-based aggregate function (it does not support code-gen) and the mode of @@ -94,24 +94,24 @@ abstract class AggregationIterator( case _ => // We only need to set inputBufferOffset for aggregate functions with mode // PartialMerge and Final. - func match { + val updatedFunc = func match { case function: ImperativeAggregate => function.withNewInputAggBufferOffset(inputBufferOffset) - case _ => + case function => function } inputBufferOffset += func.aggBufferSchema.length - func + updatedFunc } - // Set mutableBufferOffset for this function. It is important that setting - // mutableBufferOffset happens after all potential bindReference operations - // because bindReference will create a new instance of the function. - funcWithBoundReferences match { + val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match { case function: ImperativeAggregate => + // Set mutableBufferOffset for this function. 
It is important that setting + // mutableBufferOffset happens after all potential bindReference operations + // because bindReference will create a new instance of the function. function.withNewMutableAggBufferOffset(mutableBufferOffset) - case _ => + case function => function } - mutableBufferOffset += funcWithBoundReferences.aggBufferSchema.length - functions(i) = funcWithBoundReferences + mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length + functions(i) = funcWithUpdatedAggBufferOffset i += 1 } functions @@ -320,7 +320,7 @@ abstract class AggregationIterator( // Initializing the function used to generate the output row. protected val generateOutput: (InternalRow, MutableRow) => InternalRow = { val rowToBeEvaluated = new JoinedRow - val safeOutputRow = new GenericMutableRow(resultExpressions.length) + val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType)) val mutableOutput = if (outputsUnsafeRows) { UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow) } else { @@ -358,7 +358,8 @@ abstract class AggregationIterator( val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes // TODO: Use unsafe row. - val aggregateResult = new GenericMutableRow(aggregateResultSchema.length) + val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) + expressionAggEvalProjection.target(aggregateResult) val resultProjection = newMutableProjection( resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)() @@ -366,7 +367,7 @@ abstract class AggregationIterator( (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { // Generate results for all expression-based aggregate functions. - expressionAggEvalProjection.target(aggregateResult)(currentBuffer) + expressionAggEvalProjection(currentBuffer) // Generate results for all imperative aggregate functions. 
var i = 0 while (i < allImperativeAggregateFunctions.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 7b3d072b2e..c342940e6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.StructType case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -34,10 +35,18 @@ case class TungstenAggregate( nonCompleteAggregateAttributes: Seq[Attribute], completeAggregateExpressions: Seq[AggregateExpression2], completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { + private[this] val aggregateBufferAttributes = { + (nonCompleteAggregateExpressions ++ completeAggregateExpressions) + .flatMap(_.aggregateFunction.aggBufferAttributes) + } + + require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes)) + override private[sql] lazy val metrics = Map( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) @@ -82,6 +91,7 @@ case class TungstenAggregate( nonCompleteAggregateAttributes, completeAggregateExpressions, completeAggregateAttributes, + initialInputBufferOffset, resultExpressions, newMutableProjection, child.output, @@ -138,3 +148,13 @@ case class TungstenAggregate( } } } + +object TungstenAggregate { + def supportsAggregate( + groupingExpressions: Seq[Expression], + aggregateBufferAttributes: Seq[Attribute]): Boolean = { + val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeProjection.canSupport(groupingExpressions) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 4bb95c9eb7..fe708a5f71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.aggregate +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.unsafe.KVIterator import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.expressions._ @@ -79,6 +81,7 @@ class TungstenAggregationIterator( nonCompleteAggregateAttributes: Seq[Attribute], completeAggregateExpressions: Seq[AggregateExpression2], completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, resultExpressions: 
Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], @@ -134,19 +137,74 @@ class TungstenAggregationIterator( completeAggregateExpressions.map(_.mode).distinct.headOption } - // All aggregate functions. TungstenAggregationIterator only handles expression-based aggregate. - // If there is any functions that is an ImperativeAggregateFunction, we throw an - // IllegalStateException. - private[this] val allAggregateFunctions: Array[DeclarativeAggregate] = { - if (!allAggregateExpressions.forall( - _.aggregateFunction.isInstanceOf[DeclarativeAggregate])) { - throw new IllegalStateException( - "Only ExpressionAggregateFunctions should be passed in TungstenAggregationIterator.") + // Initialize all AggregateFunctions by binding references, if necessary, + // and setting inputBufferOffset and mutableBufferOffset. + private def initializeAllAggregateFunctions( + startingInputBufferOffset: Int): Array[AggregateFunction2] = { + var mutableBufferOffset = 0 + var inputBufferOffset: Int = startingInputBufferOffset + val functions = new Array[AggregateFunction2](allAggregateExpressions.length) + var i = 0 + while (i < allAggregateExpressions.length) { + val func = allAggregateExpressions(i).aggregateFunction + val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length + // We need to use this mode instead of func.mode in order to handle aggregation mode switching + // when switching to sort-based aggregation: + val mode = if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2 + val funcWithBoundReferences = mode match { + case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] => + // We need to create BoundReferences if the function is not an + // expression-based aggregate function (it does not support code-gen) and the mode of + // this function is Partial or Complete because we will call eval of this + // function's children in the update method of this aggregate function. + // Those eval calls require BoundReferences to work. + BindReferences.bindReference(func, originalInputAttributes) + case _ => + // We only need to set inputBufferOffset for aggregate functions with mode + // PartialMerge and Final. + val updatedFunc = func match { + case function: ImperativeAggregate => + function.withNewInputAggBufferOffset(inputBufferOffset) + case function => function + } + inputBufferOffset += func.aggBufferSchema.length + updatedFunc + } + val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match { + case function: ImperativeAggregate => + // Set mutableBufferOffset for this function. It is important that setting + // mutableBufferOffset happens after all potential bindReference operations + // because bindReference will create a new instance of the function. + function.withNewMutableAggBufferOffset(mutableBufferOffset) + case function => function + } + mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length + functions(i) = funcWithUpdatedAggBufferOffset + i += 1 } + functions + } - allAggregateExpressions - .map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - .toArray + private[this] var allAggregateFunctions: Array[AggregateFunction2] = + initializeAllAggregateFunctions(initialInputBufferOffset) + + // Positions of those imperative aggregate functions in allAggregateFunctions. 
+ // For example, say that we have func1, func2, func3, func4 in aggregateFunctions, and + // func2 and func3 are imperative aggregate functions. Then + // allImperativeAggregateFunctionPositions will be [1, 2]. Note that this does not need to be + // updated when falling back to sort-based aggregation because the positions of the aggregate + // functions do not change in that case. + private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { + val positions = new ArrayBuffer[Int]() + var i = 0 + while (i < allAggregateFunctions.length) { + allAggregateFunctions(i) match { + case agg: DeclarativeAggregate => + case _ => positions += i + } + i += 1 + } + positions.toArray } /////////////////////////////////////////////////////////////////////////// @@ -155,25 +213,31 @@ class TungstenAggregationIterator( // rows. /////////////////////////////////////////////////////////////////////////// - // The projection used to initialize buffer values. - private[this] val initialProjection: MutableProjection = { - val initExpressions = allAggregateFunctions.flatMap(_.initialValues) + // The projection used to initialize buffer values for all expression-based aggregates. + // Note that this projection does not need to be updated when switching to sort-based aggregation + // because the schema of empty aggregation buffers does not change in that case. + private[this] val expressionAggInitialProjection: MutableProjection = { + val initExpressions = allAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.initialValues + // For the positions corresponding to imperative aggregate functions, we'll use special + // no-op expressions which are ignored during projection code-generation. + case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) + } newMutableProjection(initExpressions, Nil)() } // Creates a new aggregation buffer and initializes buffer values. - // This functions should be only called at most three times (when we create the hash map, + // This function should be only called at most three times (when we create the hash map, // when we switch to sort-based aggregation, and when we create the re-used buffer for // sort-based aggregation). 
private def createNewAggregationBuffer(): UnsafeRow = { val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) - val bufferRowSize: Int = bufferSchema.length - - val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val unsafeProjection = - UnsafeProjection.create(bufferSchema.map(_.dataType)) - val buffer = unsafeProjection.apply(genericMutableBuffer) - initialProjection.target(buffer)(EmptyRow) + val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType)) + .apply(new GenericMutableRow(bufferSchema.length)) + // Initialize declarative aggregates' buffer values + expressionAggInitialProjection.target(buffer)(EmptyRow) + // Initialize imperative aggregates' buffer values + allAggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) buffer } @@ -187,72 +251,124 @@ class TungstenAggregationIterator( aggregationMode match { // Partial-only case (Some(Partial), None) => - val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions) - val updateProjection = + val updateExpressions = allAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val imperativeAggregateFunctions: Array[ImperativeAggregate] = + allAggregateFunctions.collect { case func: ImperativeAggregate => func} + val expressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { - updateProjection.target(currentBuffer) - updateProjection(joinedRow(currentBuffer, row)) + expressionAggUpdateProjection.target(currentBuffer) + // Process all expression-based aggregate functions. + expressionAggUpdateProjection(joinedRow(currentBuffer, row)) + // Process all imperative aggregate functions + var i = 0 + while (i < imperativeAggregateFunctions.length) { + imperativeAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } } // PartialMerge-only or Final-only case (Some(PartialMerge), None) | (Some(Final), None) => - val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions) - val mergeProjection = + val mergeExpressions = allAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val imperativeAggregateFunctions: Array[ImperativeAggregate] = + allAggregateFunctions.collect { case func: ImperativeAggregate => func} + // This projection is used to merge buffer values for all expression-based aggregates. + val expressionAggMergeProjection = newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { - mergeProjection.target(currentBuffer) - mergeProjection(joinedRow(currentBuffer, row)) + // Process all expression-based aggregate functions. + expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) + // Process all imperative aggregate functions. 
+ var i = 0 + while (i < imperativeAggregateFunctions.length) { + imperativeAggregateFunctions(i).merge(currentBuffer, row) + i += 1 + } } // Final-Complete case (Some(Final), Some(Complete)) => - val nonCompleteAggregateFunctions: Array[DeclarativeAggregate] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - val completeAggregateFunctions: Array[DeclarativeAggregate] = + val completeAggregateFunctions: Array[AggregateFunction2] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) + val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = + completeAggregateFunctions.collect { case func: ImperativeAggregate => func } + val nonCompleteAggregateFunctions: Array[AggregateFunction2] = + allAggregateFunctions.take(nonCompleteAggregateExpressions.length) + val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = + nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } val completeOffsetExpressions = Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) val mergeExpressions = - nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions + nonCompleteAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } ++ completeOffsetExpressions val finalMergeProjection = newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() // We do not touch buffer values of aggregate functions with the Final mode. val finalOffsetExpressions = Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - val updateExpressions = - finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions) + val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } val completeUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { val input = joinedRow(currentBuffer, row) - // For all aggregate functions with mode Complete, update the given currentBuffer. + // For all aggregate functions with mode Complete, update buffers. completeUpdateProjection.target(currentBuffer)(input) + var i = 0 + while (i < completeImperativeAggregateFunctions.length) { + completeImperativeAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } // For all aggregate functions with mode Final, merge buffer values in row to // currentBuffer. finalMergeProjection.target(currentBuffer)(input) + i = 0 + while (i < nonCompleteImperativeAggregateFunctions.length) { + nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) + i += 1 + } } // Complete-only case (None, Some(Complete)) => - val completeAggregateFunctions: Array[DeclarativeAggregate] = + val completeAggregateFunctions: Array[AggregateFunction2] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) + // All imperative aggregate functions with mode Complete. 
+ val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = + completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - val updateExpressions = - completeAggregateFunctions.flatMap(_.updateExpressions) - val completeUpdateProjection = + val updateExpressions = completeAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { - completeUpdateProjection.target(currentBuffer) - // For all aggregate functions with mode Complete, update the given currentBuffer. - completeUpdateProjection(joinedRow(currentBuffer, row)) + // For all aggregate functions with mode Complete, update buffers. + completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) + var i = 0 + while (i < completeImperativeAggregateFunctions.length) { + completeImperativeAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } } // Grouping only. @@ -288,17 +404,30 @@ class TungstenAggregationIterator( val joinedRow = new JoinedRow() val evalExpressions = allAggregateFunctions.map { case ae: DeclarativeAggregate => ae.evaluateExpression - // case agg: AggregateFunction2 => Literal.create(null, agg.dataType) + case agg: AggregateFunction2 => NoOp } - val expressionAggEvalProjection = UnsafeProjection.create(evalExpressions, bufferAttributes) + val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() // These are the attributes of the row produced by `expressionAggEvalProjection` val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes + val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) + expressionAggEvalProjection.target(aggregateResult) val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema) + val allImperativeAggregateFunctions: Array[ImperativeAggregate] = + allAggregateFunctions.collect { case func: ImperativeAggregate => func} + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { // Generate results for all expression-based aggregate functions. - val aggregateResult = expressionAggEvalProjection.apply(currentBuffer) + expressionAggEvalProjection(currentBuffer) + // Generate results for all imperative aggregate functions. + var i = 0 + while (i < allImperativeAggregateFunctions.length) { + aggregateResult.update( + allImperativeAggregateFunctionPositions(i), + allImperativeAggregateFunctions(i).eval(currentBuffer)) + i += 1 + } resultProjection(joinedRow(currentGroupingKey, aggregateResult)) } @@ -481,10 +610,27 @@ class TungstenAggregationIterator( // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer. // We need to project the aggregation buffer part from an input row. val buffer = createNewAggregationBuffer() - // The originalInputAttributes are using cloneBufferAttributes. So, we need to use - // allAggregateFunctions.flatMap(_.cloneBufferAttributes). + // In principle, we could use `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` to + // extract the aggregation buffer. In practice, however, we extract it positionally by relying + // on it being present at the end of the row. 
The reason for this relates to how the different + // aggregates handle input binding. + // + // ImperativeAggregate uses field numbers and field number offsets to manipulate its buffers, + // so its correctness does not rely on attribute bindings. When we fall back to sort-based + // aggregation, these field number offsets (mutableAggBufferOffset and inputAggBufferOffset) + // need to be updated and any internal state in the aggregate functions themselves must be + // reset, so we call withNewMutableAggBufferOffset and withNewInputAggBufferOffset to reset + // this state and update the offsets. + // + // The updated ImperativeAggregate will have different attribute ids for its + // aggBufferAttributes and inputAggBufferAttributes. This isn't a problem for the actual + // ImperativeAggregate evaluation, but it means that + // `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` will no longer match the + // attributes in `originalInputAttributes`, which is why we can't use those attributes here. + // + // For more details, see the discussion on PR #9038. val bufferExtractor = newMutableProjection( - allAggregateFunctions.flatMap(_.inputAggBufferAttributes), + originalInputAttributes.drop(initialInputBufferOffset), originalInputAttributes)() bufferExtractor.target(buffer) @@ -511,8 +657,10 @@ class TungstenAggregationIterator( } aggregationMode = newAggregationMode + allAggregateFunctions = initializeAllAggregateFunctions(startingInputBufferOffset = 0) + // Basically the value of the KVIterator returned by externalSorter - // will just aggregation buffer. At here, we use cloneBufferAttributes. + // will just aggregation buffer. At here, we use inputAggBufferAttributes. val newInputAttributes: Seq[Attribute] = allAggregateFunctions.flatMap(_.inputAggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index fd02be1225..d2f56e0fc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -321,9 +321,17 @@ private[sql] class InputAggregationBuffer private[sql] ( */ private[sql] case class ScalaUDAF( children: Seq[Expression], - udaf: UserDefinedAggregateFunction) + udaf: UserDefinedAggregateFunction, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends ImperativeAggregate with Logging { + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + require( children.length == udaf.inputSchema.length, s"$udaf only accepts ${udaf.inputSchema.length} arguments, " + @@ -341,6 +349,11 @@ private[sql] case class ScalaUDAF( override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. 
+ override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + private[this] lazy val childrenSchema: StructType = { val inputFields = children.zipWithIndex.map { case (child, index) => @@ -382,51 +395,33 @@ private[sql] case class ScalaUDAF( } // This buffer is only used at executor side. - private[this] var inputAggregateBuffer: InputAggregationBuffer = null - - // This buffer is only used at executor side. - private[this] var mutableAggregateBuffer: MutableAggregationBufferImpl = null + private[this] lazy val inputAggregateBuffer: InputAggregationBuffer = { + new InputAggregationBuffer( + aggBufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + inputAggBufferOffset, + null) + } // This buffer is only used at executor side. - private[this] var evalAggregateBuffer: InputAggregationBuffer = null - - /** - * Sets the inputBufferOffset to newInputBufferOffset and then create a new instance of - * `inputAggregateBuffer` based on this new inputBufferOffset. - */ - override def withNewInputAggBufferOffset(newInputBufferOffset: Int): Unit = { - super.withNewInputAggBufferOffset(newInputBufferOffset) - // inputBufferOffset has been updated. - inputAggregateBuffer = - new InputAggregationBuffer( - aggBufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - inputAggBufferOffset, - null) + private[this] lazy val mutableAggregateBuffer: MutableAggregationBufferImpl = { + new MutableAggregationBufferImpl( + aggBufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + mutableAggBufferOffset, + null) } - /** - * Sets the mutableBufferOffset to newMutableBufferOffset and then create a new instance of - * `mutableAggregateBuffer` and `evalAggregateBuffer` based on this new mutableBufferOffset. - */ - override def withNewMutableAggBufferOffset(newMutableBufferOffset: Int): Unit = { - super.withNewMutableAggBufferOffset(newMutableBufferOffset) - // mutableBufferOffset has been updated. - mutableAggregateBuffer = - new MutableAggregationBufferImpl( - aggBufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - mutableAggBufferOffset, - null) - evalAggregateBuffer = - new InputAggregationBuffer( - aggBufferSchema, - bufferValuesToCatalystConverters, - bufferValuesToScalaConverters, - mutableAggBufferOffset, - null) + // This buffer is only used at executor side. 
+ private[this] lazy val evalAggregateBuffer: InputAggregationBuffer = { + new InputAggregationBuffer( + aggBufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + mutableAggBufferOffset, + null) } override def initialize(buffer: MutableRow): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index cf6e7ed0d3..eaafd83158 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -19,21 +19,12 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.SparkPlan /** * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object Utils { - def supportsTungstenAggregate( - groupingExpressions: Seq[Expression], - aggregateBufferAttributes: Seq[Attribute]): Boolean = { - val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) - - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeProjection.canSupport(groupingExpressions) - } def planAggregateWithoutPartial( groupingExpressions: Seq[NamedExpression], @@ -70,8 +61,7 @@ object Utils { // Check if we can use TungstenAggregate. val usesTungstenAggregate = child.sqlContext.conf.unsafeEnabled && - aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[DeclarativeAggregate]) && - supportsTungstenAggregate( + TungstenAggregate.supportsAggregate( groupingExpressions, aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) @@ -94,6 +84,7 @@ object Utils { nonCompleteAggregateAttributes = partialAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, + initialInputBufferOffset = 0, resultExpressions = partialResultExpressions, child = child) } else { @@ -125,6 +116,7 @@ object Utils { nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, + initialInputBufferOffset = groupingExpressions.length, resultExpressions = resultExpressions, child = partialAggregate) } else { @@ -154,143 +146,150 @@ object Utils { val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct val usesTungstenAggregate = child.sqlContext.conf.unsafeEnabled && - aggregateExpressions.forall( - _.aggregateFunction.isInstanceOf[DeclarativeAggregate]) && - supportsTungstenAggregate( + TungstenAggregate.supportsAggregate( groupingExpressions, aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) - // 1. Create an Aggregate Operator for partial aggregations. - val groupingAttributes = groupingExpressions.map(_.toAttribute) - - // It is safe to call head at here since functionsWithDistinct has at least one - // AggregateExpression2. - val distinctColumnExpressions = - functionsWithDistinct.head.aggregateFunction.children - val namedDistinctColumnExpressions = distinctColumnExpressions.map { - case ne: NamedExpression => ne -> ne - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias + // functionsWithDistinct is guaranteed to be non-empty. 
Even though it may contain more than one + // DISTINCT aggregate function, all of those functions will have the same column expression. + // For example, it would be valid for functionsWithDistinct to be + // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is + // disallowed because those two distinct aggregates have different column expressions. + val distinctColumnExpression: Expression = { + val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children + assert(allDistinctColumnExpressions.length == 1) + allDistinctColumnExpressions.head + } + val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() } - val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap - val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute) + val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute + val groupingAttributes = groupingExpressions.map(_.toAttribute) - val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - val partialAggregateAttributes = - partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - val partialAggregateGroupingExpressions = - groupingExpressions ++ namedDistinctColumnExpressions.map(_._2) - val partialAggregateResult = + // 1. Create an Aggregate Operator for partial aggregations. + val partialAggregate: SparkPlan = { + val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + // We will group by the original grouping expression, plus an additional expression for the + // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping + // expressions will be [key, value]. + val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression + val partialAggregateResult = groupingAttributes ++ - distinctColumnAttributes ++ - partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - val partialAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = None, - // The grouping expressions are original groupingExpressions and - // distinct columns. For example, for avg(distinct value) ... group by key - // the grouping expressions of this Aggregate Operator will be [key, value]. 
- groupingExpressions = partialAggregateGroupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - resultExpressions = partialAggregateResult, - child = child) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = None, - groupingExpressions = partialAggregateGroupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialAggregateResult, - child = child) + Seq(distinctColumnAttribute) ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = None, + groupingExpressions = partialAggregateGroupingExpressions, + nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialAggregateResult, + child = child) + } else { + SortBasedAggregate( + requiredChildDistributionExpressions = None, + groupingExpressions = partialAggregateGroupingExpressions, + nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialAggregateResult, + child = child) + } } // 2. Create an Aggregate Operator for partial merge aggregations. 
- val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - val partialMergeAggregateResult = + val partialMergeAggregate: SparkPlan = { + val partialMergeAggregateExpressions = + functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val partialMergeAggregateAttributes = + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + val partialMergeAggregateResult = groupingAttributes ++ - distinctColumnAttributes ++ - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - val partialMergeAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes ++ distinctColumnAttributes, - nonCompleteAggregateExpressions = partialMergeAggregateExpressions, - nonCompleteAggregateAttributes = partialMergeAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - resultExpressions = partialMergeAggregateResult, - child = partialAggregate) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes ++ distinctColumnAttributes, - nonCompleteAggregateExpressions = partialMergeAggregateExpressions, - nonCompleteAggregateAttributes = partialMergeAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = partialMergeAggregateResult, - child = partialAggregate) + Seq(distinctColumnAttribute) ++ + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes :+ distinctColumnAttribute, + nonCompleteAggregateExpressions = partialMergeAggregateExpressions, + nonCompleteAggregateAttributes = partialMergeAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + resultExpressions = partialMergeAggregateResult, + child = partialAggregate) + } else { + SortBasedAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes :+ distinctColumnAttribute, + nonCompleteAggregateExpressions = partialMergeAggregateExpressions, + nonCompleteAggregateAttributes = partialMergeAggregateAttributes, + completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + resultExpressions = partialMergeAggregateResult, + child = partialAggregate) + } } - // 3. Create an Aggregate Operator for partial merge aggregations. - val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) - // The attributes of the final aggregation buffer, which is presented as input to the result - // projection: - val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + // 3. Create an Aggregate Operator for the final aggregation. 
+ val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } - val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { - // Children of an AggregateFunction with DISTINCT keyword has already - // been evaluated. At here, we need to replace original children - // to AttributeReferences. - case agg @ AggregateExpression2(aggregateFunction, mode, true) => - val rewrittenAggregateFunction = aggregateFunction.transformDown { - case expr if distinctColumnExpressionMap.contains(expr) => - distinctColumnExpressionMap(expr).toAttribute - }.asInstanceOf[AggregateFunction2] - // We rewrite the aggregate function to a non-distinct aggregation because - // its input will have distinct arguments. - // We just keep the isDistinct setting to true, so when users look at the query plan, - // they still can see distinct aggregations. - val rewrittenAggregateExpression = - AggregateExpression2(rewrittenAggregateFunction, Complete, true) + val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { + // Children of an AggregateFunction with DISTINCT keyword has already + // been evaluated. At here, we need to replace original children + // to AttributeReferences. + case agg @ AggregateExpression2(aggregateFunction, mode, true) => + val rewrittenAggregateFunction = aggregateFunction.transformDown { + case expr if expr == distinctColumnExpression => distinctColumnAttribute + }.asInstanceOf[AggregateFunction2] + // We rewrite the aggregate function to a non-distinct aggregation because + // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. 
+ val rewrittenAggregateExpression = + AggregateExpression2(rewrittenAggregateFunction, Complete, isDistinct = true) - val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) - (rewrittenAggregateExpression, aggregateFunctionAttribute) - }.unzip - - val finalAndCompleteAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, - resultExpressions = resultExpressions, - child = partialMergeAggregate) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = resultExpressions, - child = partialMergeAggregate) + val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) + (rewrittenAggregateExpression, aggregateFunctionAttribute) + }.unzip + if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, + completeAggregateExpressions = completeAggregateExpressions, + completeAggregateAttributes = completeAggregateAttributes, + initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + resultExpressions = resultExpressions, + child = partialMergeAggregate) + } else { + SortBasedAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, + completeAggregateExpressions = completeAggregateExpressions, + completeAggregateAttributes = completeAggregateAttributes, + initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + resultExpressions = resultExpressions, + child = partialMergeAggregate) + } } finalAndCompleteAggregate :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index ed974b3a53..0cc4988ff6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -39,7 +39,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte } val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy") iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty, - Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) + 0, Seq.empty, newMutableProjection, Seq.empty, None, 
dummyAccum, dummyAccum) val numPages = iter.getHashMap.getNumDataPages assert(numPages === 1) } finally { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 18bbdb9908..a2ebf6552f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -553,10 +553,16 @@ private[hive] case class HiveGenericUDTF( private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, children: Seq[Expression], - isUDAFBridgeRequired: Boolean = false) + isUDAFBridgeRequired: Boolean = false, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends ImperativeAggregate with HiveInspectors { - def this() = this(null, null) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) @transient private lazy val resolver = @@ -614,7 +620,11 @@ private[hive] case class HiveUDAFFunction( buffer = function.getNewAggregationBuffer } - override def aggBufferAttributes: Seq[AttributeReference] = Nil + override val aggBufferAttributes: Seq[AttributeReference] = Nil + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = Nil // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our // catalyst type checking framework. |
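
To make the offset bookkeeping above concrete, here is a minimal sketch of the walk that AggregationIterator and TungstenAggregationIterator.initializeAllAggregateFunctions perform: each function is replaced by a copy positioned at the running mutable and input buffer offsets, and the running offsets then advance by that function's buffer width. FakeAggregate, bufferWidth, and assignOffsets are hypothetical names used only for illustration, and the Partial/Complete bind-references branch of the real code is omitted.

object OffsetAssignmentSketch {

  // Hypothetical stand-in for an imperative aggregate; bufferWidth plays the role of
  // aggBufferSchema.length in the real code.
  case class FakeAggregate(
      bufferWidth: Int,
      mutableAggBufferOffset: Int = 0,
      inputAggBufferOffset: Int = 0) {
    def withNewMutableAggBufferOffset(n: Int): FakeAggregate = copy(mutableAggBufferOffset = n)
    def withNewInputAggBufferOffset(n: Int): FakeAggregate = copy(inputAggBufferOffset = n)
  }

  // Walk the functions in order, handing each one the current running offsets and then
  // advancing both offsets by that function's buffer width.
  def assignOffsets(
      functions: Seq[FakeAggregate],
      startingInputBufferOffset: Int): Seq[FakeAggregate] = {
    var mutableOffset = 0
    var inputOffset = startingInputBufferOffset
    functions.map { f =>
      val positioned = f
        .withNewInputAggBufferOffset(inputOffset)
        .withNewMutableAggBufferOffset(mutableOffset)
      inputOffset += f.bufferWidth
      mutableOffset += f.bufferWidth
      positioned
    }
  }

  def main(args: Array[String]): Unit = {
    // Two aggregates with buffer widths 1 and 2, reading an input buffer whose first
    // 3 columns are grouping columns (so the starting input offset is 3).
    val positioned =
      assignOffsets(Seq(FakeAggregate(1), FakeAggregate(2)), startingInputBufferOffset = 3)
    // Prints: List((0,3), (1,4))
    println(positioned.map(f => (f.mutableAggBufferOffset, f.inputAggBufferOffset)))
  }
}
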