-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala | 19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala | 31
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala | 29
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala | 250
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala | 79
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala | 269
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 16
9 files changed, 457 insertions(+), 260 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 8aad0b7dee..c0bc7ec09c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -472,10 +472,20 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
* @param relativeSD the maximum estimation error allowed.
*/
// scalastyle:on
-case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
- extends ImperativeAggregate {
+case class HyperLogLogPlusPlus(
+ child: Expression,
+ relativeSD: Double = 0.05,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends ImperativeAggregate {
import HyperLogLogPlusPlus._
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
/**
* HLL++ uses 'p' bits for addressing. The more addressing bits we use, the more precise the
* algorithm will be, and the more memory it will require. The 'p' value is based on the relative
@@ -546,6 +556,11 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
AttributeReference(s"MS[$i]", LongType)()
}
+ // Note: although this simply copies aggBufferAttributes, this common code cannot be placed
+ // in the superclass because that will lead to initialization ordering issues.
+ override val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
+
/** Fill all words with zeros. */
override def initialize(buffer: MutableRow): Unit = {
var word = 0
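
The pattern above — offsets as constructor parameters plus copy-based withNew*Offset methods — recurs in every ImperativeAggregate touched by this diff. A minimal standalone sketch of the pattern, using hypothetical stand-in types (Agg, Hll) rather than the actual Spark classes:

// Simplified sketch: `Agg` and `Hll` are hypothetical stand-ins for
// ImperativeAggregate and HyperLogLogPlusPlus.
trait Agg {
  def mutableAggBufferOffset: Int
  def inputAggBufferOffset: Int
  def withNewMutableAggBufferOffset(newOffset: Int): Agg
  def withNewInputAggBufferOffset(newOffset: Int): Agg
}

case class Hll(
    relativeSD: Double = 0.05,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Agg {
  // Returning a copy keeps every instance immutable; nothing downstream can
  // observe a half-updated offset.
  override def withNewMutableAggBufferOffset(newOffset: Int): Agg =
    copy(mutableAggBufferOffset = newOffset)
  override def withNewInputAggBufferOffset(newOffset: Int): Agg =
    copy(inputAggBufferOffset = newOffset)
}
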
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 9ba3a9c980..a2fab258fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -150,6 +150,10 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
* We need to perform similar field number arithmetic when merging multiple intermediate
* aggregate buffers together in `merge()` (in this case, use `inputAggBufferOffset` when accessing
* the input buffer).
+ *
+ * Correct ImperativeAggregate evaluation depends on the correctness of `mutableAggBufferOffset` and
+ * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes`
+ * and `inputAggBufferAttributes`.
*/
abstract class ImperativeAggregate extends AggregateFunction2 {
@@ -172,11 +176,13 @@ abstract class ImperativeAggregate extends AggregateFunction2 {
* avg(y) mutableAggBufferOffset = 2
*
*/
- protected var mutableAggBufferOffset: Int = 0
+ protected val mutableAggBufferOffset: Int
- def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Unit = {
- mutableAggBufferOffset = newMutableAggBufferOffset
- }
+ /**
+ * Returns a copy of this ImperativeAggregate with an updated mutableAggBufferOffset.
+ * This new copy's attributes may have different ids than the original.
+ */
+ def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate
/**
* The offset of this function's start buffer value in the underlying shared input aggregation
@@ -203,11 +209,17 @@ abstract class ImperativeAggregate extends AggregateFunction2 {
* avg(y) inputAggBufferOffset = 3
*
*/
- protected var inputAggBufferOffset: Int = 0
+ protected val inputAggBufferOffset: Int
- def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Unit = {
- inputAggBufferOffset = newInputAggBufferOffset
- }
+ /**
+ * Returns a copy of this ImperativeAggregate with an updated inputAggBufferOffset.
+ * This new copy's attributes may have different ids than the original.
+ */
+ def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate
+
+ // Note: although all subclasses implement inputAggBufferAttributes by simply cloning
+ // aggBufferAttributes, that common clone code cannot be placed here in the abstract
+ // ImperativeAggregate class, since that will lead to initialization ordering issues.
/**
* Initializes the mutable aggregation buffer located in `mutableAggBuffer`.
@@ -231,9 +243,6 @@ abstract class ImperativeAggregate extends AggregateFunction2 {
* Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`.
*/
def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit
-
- final lazy val inputAggBufferAttributes: Seq[AttributeReference] =
- aggBufferAttributes.map(_.newInstance())
}
/**
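
To make the offset arithmetic that this interface documents concrete, here is a small self-contained example in plain Scala with a hypothetical buffer layout: sum(x) with a one-slot buffer and avg(y) with a two-slot buffer sharing one flat mutable buffer.

// Hypothetical layout: sum(x) occupies slot 0, avg(y) occupies slots 1 and 2.
val sharedBuffer = new Array[Any](3)
val sumMutableAggBufferOffset = 0 // sum(x) starts at field 0
val avgMutableAggBufferOffset = 1 // avg(y) starts after sum's 1-slot buffer
// Inside avg(y), its own fields are numbered 0 (sum) and 1 (count); the
// physical slot is always fieldNumber + mutableAggBufferOffset:
sharedBuffer(0 + avgMutableAggBufferOffset) = 0.0 // avg's running sum
sharedBuffer(1 + avgMutableAggBufferOffset) = 0L  // avg's running count
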
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 8e0fbd109b..99fb7a40b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -83,7 +83,7 @@ abstract class AggregationIterator(
var i = 0
while (i < allAggregateExpressions.length) {
val func = allAggregateExpressions(i).aggregateFunction
- val funcWithBoundReferences = allAggregateExpressions(i).mode match {
+ val funcWithBoundReferences: AggregateFunction2 = allAggregateExpressions(i).mode match {
case Partial | Complete if func.isInstanceOf[ImperativeAggregate] =>
// We need to create BoundReferences if the function is not an
// expression-based aggregate function (it does not support code-gen) and the mode of
@@ -94,24 +94,24 @@ abstract class AggregationIterator(
case _ =>
// We only need to set inputBufferOffset for aggregate functions with mode
// PartialMerge and Final.
- func match {
+ val updatedFunc = func match {
case function: ImperativeAggregate =>
function.withNewInputAggBufferOffset(inputBufferOffset)
- case _ =>
+ case function => function
}
inputBufferOffset += func.aggBufferSchema.length
- func
+ updatedFunc
}
- // Set mutableBufferOffset for this function. It is important that setting
- // mutableBufferOffset happens after all potential bindReference operations
- // because bindReference will create a new instance of the function.
- funcWithBoundReferences match {
+ val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
case function: ImperativeAggregate =>
+ // Set mutableBufferOffset for this function. It is important that setting
+ // mutableBufferOffset happens after all potential bindReference operations
+ // because bindReference will create a new instance of the function.
function.withNewMutableAggBufferOffset(mutableBufferOffset)
- case _ =>
+ case function => function
}
- mutableBufferOffset += funcWithBoundReferences.aggBufferSchema.length
- functions(i) = funcWithBoundReferences
+ mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length
+ functions(i) = funcWithUpdatedAggBufferOffset
i += 1
}
functions
@@ -320,7 +320,7 @@ abstract class AggregationIterator(
// Initializing the function used to generate the output row.
protected val generateOutput: (InternalRow, MutableRow) => InternalRow = {
val rowToBeEvaluated = new JoinedRow
- val safeOutputRow = new GenericMutableRow(resultExpressions.length)
+ val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType))
val mutableOutput = if (outputsUnsafeRows) {
UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow)
} else {
@@ -358,7 +358,8 @@ abstract class AggregationIterator(
val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)()
val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
// TODO: Use unsafe row.
- val aggregateResult = new GenericMutableRow(aggregateResultSchema.length)
+ val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType))
+ expressionAggEvalProjection.target(aggregateResult)
val resultProjection =
newMutableProjection(
resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)()
@@ -366,7 +367,7 @@ abstract class AggregationIterator(
(currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
// Generate results for all expression-based aggregate functions.
- expressionAggEvalProjection.target(aggregateResult)(currentBuffer)
+ expressionAggEvalProjection(currentBuffer)
// Generate results for all imperative aggregate functions.
var i = 0
while (i < allImperativeAggregateFunctions.length) {
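
The initialization loop rewritten above now threads its running offsets through function copies rather than mutating the functions in place. A self-contained sketch of that accumulation, with a hypothetical Func type standing in for AggregateFunction2:

// Hypothetical stand-in: each function knows its buffer length and carries
// the offset it was assigned.
case class Func(name: String, bufferLen: Int, mutableOffset: Int = 0)

def assignOffsets(funcs: Seq[Func]): Seq[Func] = {
  var running = 0
  funcs.map { f =>
    val assigned = f.copy(mutableOffset = running) // copy, not mutation
    running += f.bufferLen                         // advance past f's buffer
    assigned
  }
}

// assignOffsets(Seq(Func("sum", 1), Func("avg", 2), Func("max", 1)))
// yields offsets 0, 1, 3 -- matching the layout described in interfaces.scala.
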
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 7b3d072b2e..c342940e6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.StructType
case class TungstenAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -34,10 +35,18 @@ case class TungstenAggregate(
nonCompleteAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2],
completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryNode {
+ private[this] val aggregateBufferAttributes = {
+ (nonCompleteAggregateExpressions ++ completeAggregateExpressions)
+ .flatMap(_.aggregateFunction.aggBufferAttributes)
+ }
+
+ require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes))
+
override private[sql] lazy val metrics = Map(
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -82,6 +91,7 @@ case class TungstenAggregate(
nonCompleteAggregateAttributes,
completeAggregateExpressions,
completeAggregateAttributes,
+ initialInputBufferOffset,
resultExpressions,
newMutableProjection,
child.output,
@@ -138,3 +148,13 @@ case class TungstenAggregate(
}
}
}
+
+object TungstenAggregate {
+ def supportsAggregate(
+ groupingExpressions: Seq[Expression],
+ aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+ val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
+ UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
+ UnsafeProjection.canSupport(groupingExpressions)
+ }
+}
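
The new supportsAggregate check gates the Tungsten path on the aggregation buffer schema being representable in an UnsafeRow-backed map. A rough standalone illustration of that kind of gate, under a hypothetical type model (the real check delegates to UnsafeFixedWidthAggregationMap and UnsafeProjection):

sealed trait SimpleType
case object IntType extends SimpleType     // fixed-width: supported
case object DoubleType extends SimpleType  // fixed-width: supported
case object VarLenType extends SimpleType  // variable-width: not supported

// The hash map needs every buffer field to be a mutable fixed-width type so
// buffers can be updated in place inside UnsafeRows.
def supportsAggregate(bufferTypes: Seq[SimpleType]): Boolean =
  bufferTypes.forall {
    case IntType | DoubleType => true
    case _                    => false
  }
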
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 4bb95c9eb7..fe708a5f71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.aggregate
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.unsafe.KVIterator
import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext}
import org.apache.spark.sql.catalyst.expressions._
@@ -79,6 +81,7 @@ class TungstenAggregationIterator(
nonCompleteAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2],
completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
@@ -134,19 +137,74 @@ class TungstenAggregationIterator(
completeAggregateExpressions.map(_.mode).distinct.headOption
}
- // All aggregate functions. TungstenAggregationIterator only handles expression-based aggregate.
- // If there is any functions that is an ImperativeAggregateFunction, we throw an
- // IllegalStateException.
- private[this] val allAggregateFunctions: Array[DeclarativeAggregate] = {
- if (!allAggregateExpressions.forall(
- _.aggregateFunction.isInstanceOf[DeclarativeAggregate])) {
- throw new IllegalStateException(
- "Only ExpressionAggregateFunctions should be passed in TungstenAggregationIterator.")
+ // Initialize all AggregateFunctions by binding references, if necessary,
+ // and setting inputBufferOffset and mutableBufferOffset.
+ private def initializeAllAggregateFunctions(
+ startingInputBufferOffset: Int): Array[AggregateFunction2] = {
+ var mutableBufferOffset = 0
+ var inputBufferOffset: Int = startingInputBufferOffset
+ val functions = new Array[AggregateFunction2](allAggregateExpressions.length)
+ var i = 0
+ while (i < allAggregateExpressions.length) {
+ val func = allAggregateExpressions(i).aggregateFunction
+ val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length
+ // We need to use this mode instead of func.mode in order to handle aggregation mode switching
+ // when switching to sort-based aggregation:
+ val mode = if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2
+ val funcWithBoundReferences = mode match {
+ case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] =>
+ // We need to create BoundReferences if the function is not an
+ // expression-based aggregate function (it does not support code-gen) and the mode of
+ // this function is Partial or Complete because we will call eval of this
+ // function's children in the update method of this aggregate function.
+ // Those eval calls require BoundReferences to work.
+ BindReferences.bindReference(func, originalInputAttributes)
+ case _ =>
+ // We only need to set inputBufferOffset for aggregate functions with mode
+ // PartialMerge and Final.
+ val updatedFunc = func match {
+ case function: ImperativeAggregate =>
+ function.withNewInputAggBufferOffset(inputBufferOffset)
+ case function => function
+ }
+ inputBufferOffset += func.aggBufferSchema.length
+ updatedFunc
+ }
+ val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
+ case function: ImperativeAggregate =>
+ // Set mutableBufferOffset for this function. It is important that setting
+ // mutableBufferOffset happens after all potential bindReference operations
+ // because bindReference will create a new instance of the function.
+ function.withNewMutableAggBufferOffset(mutableBufferOffset)
+ case function => function
+ }
+ mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length
+ functions(i) = funcWithUpdatedAggBufferOffset
+ i += 1
}
+ functions
+ }
- allAggregateExpressions
- .map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
- .toArray
+ private[this] var allAggregateFunctions: Array[AggregateFunction2] =
+ initializeAllAggregateFunctions(initialInputBufferOffset)
+
+ // Positions of those imperative aggregate functions in allAggregateFunctions.
+ // For example, say that we have func1, func2, func3, func4 in aggregateFunctions, and
+ // func2 and func3 are imperative aggregate functions. Then
+ // allImperativeAggregateFunctionPositions will be [1, 2]. Note that this does not need to be
+ // updated when falling back to sort-based aggregation because the positions of the aggregate
+ // functions do not change in that case.
+ private[this] val allImperativeAggregateFunctionPositions: Array[Int] = {
+ val positions = new ArrayBuffer[Int]()
+ var i = 0
+ while (i < allAggregateFunctions.length) {
+ allAggregateFunctions(i) match {
+ case agg: DeclarativeAggregate =>
+ case _ => positions += i
+ }
+ i += 1
+ }
+ positions.toArray
}
///////////////////////////////////////////////////////////////////////////
@@ -155,25 +213,31 @@ class TungstenAggregationIterator(
// rows.
///////////////////////////////////////////////////////////////////////////
- // The projection used to initialize buffer values.
- private[this] val initialProjection: MutableProjection = {
- val initExpressions = allAggregateFunctions.flatMap(_.initialValues)
+ // The projection used to initialize buffer values for all expression-based aggregates.
+ // Note that this projection does not need to be updated when switching to sort-based aggregation
+ // because the schema of empty aggregation buffers does not change in that case.
+ private[this] val expressionAggInitialProjection: MutableProjection = {
+ val initExpressions = allAggregateFunctions.flatMap {
+ case ae: DeclarativeAggregate => ae.initialValues
+ // For the positions corresponding to imperative aggregate functions, we'll use special
+ // no-op expressions which are ignored during projection code-generation.
+ case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp)
+ }
newMutableProjection(initExpressions, Nil)()
}
// Creates a new aggregation buffer and initializes buffer values.
- // This functions should be only called at most three times (when we create the hash map,
+ // This function should be called at most three times (when we create the hash map,
// when we switch to sort-based aggregation, and when we create the re-used buffer for
// sort-based aggregation).
private def createNewAggregationBuffer(): UnsafeRow = {
val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
- val bufferRowSize: Int = bufferSchema.length
-
- val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
- val unsafeProjection =
- UnsafeProjection.create(bufferSchema.map(_.dataType))
- val buffer = unsafeProjection.apply(genericMutableBuffer)
- initialProjection.target(buffer)(EmptyRow)
+ val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType))
+ .apply(new GenericMutableRow(bufferSchema.length))
+ // Initialize declarative aggregates' buffer values
+ expressionAggInitialProjection.target(buffer)(EmptyRow)
+ // Initialize imperative aggregates' buffer values
+ allAggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
buffer
}
@@ -187,72 +251,124 @@ class TungstenAggregationIterator(
aggregationMode match {
// Partial-only
case (Some(Partial), None) =>
- val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions)
- val updateProjection =
+ val updateExpressions = allAggregateFunctions.flatMap {
+ case ae: DeclarativeAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ }
+ val imperativeAggregateFunctions: Array[ImperativeAggregate] =
+ allAggregateFunctions.collect { case func: ImperativeAggregate => func }
+ val expressionAggUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
(currentBuffer: UnsafeRow, row: InternalRow) => {
- updateProjection.target(currentBuffer)
- updateProjection(joinedRow(currentBuffer, row))
+ expressionAggUpdateProjection.target(currentBuffer)
+ // Process all expression-based aggregate functions.
+ expressionAggUpdateProjection(joinedRow(currentBuffer, row))
+ // Process all imperative aggregate functions
+ var i = 0
+ while (i < imperativeAggregateFunctions.length) {
+ imperativeAggregateFunctions(i).update(currentBuffer, row)
+ i += 1
+ }
}
// PartialMerge-only or Final-only
case (Some(PartialMerge), None) | (Some(Final), None) =>
- val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions)
- val mergeProjection =
+ val mergeExpressions = allAggregateFunctions.flatMap {
+ case ae: DeclarativeAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ }
+ val imperativeAggregateFunctions: Array[ImperativeAggregate] =
+ allAggregateFunctions.collect { case func: ImperativeAggregate => func }
+ // This projection is used to merge buffer values for all expression-based aggregates.
+ val expressionAggMergeProjection =
newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)()
(currentBuffer: UnsafeRow, row: InternalRow) => {
- mergeProjection.target(currentBuffer)
- mergeProjection(joinedRow(currentBuffer, row))
+ // Process all expression-based aggregate functions.
+ expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, row))
+ // Process all imperative aggregate functions.
+ var i = 0
+ while (i < imperativeAggregateFunctions.length) {
+ imperativeAggregateFunctions(i).merge(currentBuffer, row)
+ i += 1
+ }
}
// Final-Complete
case (Some(Final), Some(Complete)) =>
- val nonCompleteAggregateFunctions: Array[DeclarativeAggregate] =
- allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
- val completeAggregateFunctions: Array[DeclarativeAggregate] =
+ val completeAggregateFunctions: Array[AggregateFunction2] =
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+ val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
+ completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
+ val nonCompleteAggregateFunctions: Array[AggregateFunction2] =
+ allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
+ val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] =
+ nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func }
val completeOffsetExpressions =
Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
val mergeExpressions =
- nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions
+ nonCompleteAggregateFunctions.flatMap {
+ case ae: DeclarativeAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ } ++ completeOffsetExpressions
val finalMergeProjection =
newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)()
// We do not touch buffer values of aggregate functions with the Final mode.
val finalOffsetExpressions =
Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
- val updateExpressions =
- finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions)
+ val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
+ case ae: DeclarativeAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ }
val completeUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
(currentBuffer: UnsafeRow, row: InternalRow) => {
val input = joinedRow(currentBuffer, row)
- // For all aggregate functions with mode Complete, update the given currentBuffer.
+ // For all aggregate functions with mode Complete, update buffers.
completeUpdateProjection.target(currentBuffer)(input)
+ var i = 0
+ while (i < completeImperativeAggregateFunctions.length) {
+ completeImperativeAggregateFunctions(i).update(currentBuffer, row)
+ i += 1
+ }
// For all aggregate functions with mode Final, merge buffer values in row to
// currentBuffer.
finalMergeProjection.target(currentBuffer)(input)
+ i = 0
+ while (i < nonCompleteImperativeAggregateFunctions.length) {
+ nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row)
+ i += 1
+ }
}
// Complete-only
case (None, Some(Complete)) =>
- val completeAggregateFunctions: Array[DeclarativeAggregate] =
+ val completeAggregateFunctions: Array[AggregateFunction2] =
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+ // All imperative aggregate functions with mode Complete.
+ val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
+ completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
- val updateExpressions =
- completeAggregateFunctions.flatMap(_.updateExpressions)
- val completeUpdateProjection =
+ val updateExpressions = completeAggregateFunctions.flatMap {
+ case ae: DeclarativeAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
+ }
+ val completeExpressionAggUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
(currentBuffer: UnsafeRow, row: InternalRow) => {
- completeUpdateProjection.target(currentBuffer)
- // For all aggregate functions with mode Complete, update the given currentBuffer.
- completeUpdateProjection(joinedRow(currentBuffer, row))
+ // For all aggregate functions with mode Complete, update buffers.
+ completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer, row))
+ var i = 0
+ while (i < completeImperativeAggregateFunctions.length) {
+ completeImperativeAggregateFunctions(i).update(currentBuffer, row)
+ i += 1
+ }
}
// Grouping only.
@@ -288,17 +404,30 @@ class TungstenAggregationIterator(
val joinedRow = new JoinedRow()
val evalExpressions = allAggregateFunctions.map {
case ae: DeclarativeAggregate => ae.evaluateExpression
- // case agg: AggregateFunction2 => Literal.create(null, agg.dataType)
+ case agg: AggregateFunction2 => NoOp
}
- val expressionAggEvalProjection = UnsafeProjection.create(evalExpressions, bufferAttributes)
+ val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)()
// These are the attributes of the row produced by `expressionAggEvalProjection`
val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
+ val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType))
+ expressionAggEvalProjection.target(aggregateResult)
val resultProjection =
UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema)
+ val allImperativeAggregateFunctions: Array[ImperativeAggregate] =
+ allAggregateFunctions.collect { case func: ImperativeAggregate => func }
+
(currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
// Generate results for all expression-based aggregate functions.
- val aggregateResult = expressionAggEvalProjection.apply(currentBuffer)
+ expressionAggEvalProjection(currentBuffer)
+ // Generate results for all imperative aggregate functions.
+ var i = 0
+ while (i < allImperativeAggregateFunctions.length) {
+ aggregateResult.update(
+ allImperativeAggregateFunctionPositions(i),
+ allImperativeAggregateFunctions(i).eval(currentBuffer))
+ i += 1
+ }
resultProjection(joinedRow(currentGroupingKey, aggregateResult))
}
@@ -481,10 +610,27 @@ class TungstenAggregationIterator(
// When needsProcess is false, the format of input rows is groupingKey + aggregation buffer.
// We need to project the aggregation buffer part from an input row.
val buffer = createNewAggregationBuffer()
- // The originalInputAttributes are using cloneBufferAttributes. So, we need to use
- // allAggregateFunctions.flatMap(_.cloneBufferAttributes).
+ // In principle, we could use `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` to
+ // extract the aggregation buffer. In practice, however, we extract it positionally by relying
+ // on it being present at the end of the row. The reason for this relates to how the different
+ // aggregates handle input binding.
+ //
+ // ImperativeAggregate uses field numbers and field number offsets to manipulate its buffers,
+ // so its correctness does not rely on attribute bindings. When we fall back to sort-based
+ // aggregation, these field number offsets (mutableAggBufferOffset and inputAggBufferOffset)
+ // need to be updated and any internal state in the aggregate functions themselves must be
+ // reset, so we call withNewMutableAggBufferOffset and withNewInputAggBufferOffset to reset
+ // this state and update the offsets.
+ //
+ // The updated ImperativeAggregate will have different attribute ids for its
+ // aggBufferAttributes and inputAggBufferAttributes. This isn't a problem for the actual
+ // ImperativeAggregate evaluation, but it means that
+ // `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` will no longer match the
+ // attributes in `originalInputAttributes`, which is why we can't use those attributes here.
+ //
+ // For more details, see the discussion on PR #9038.
val bufferExtractor = newMutableProjection(
- allAggregateFunctions.flatMap(_.inputAggBufferAttributes),
+ originalInputAttributes.drop(initialInputBufferOffset),
originalInputAttributes)()
bufferExtractor.target(buffer)
@@ -511,8 +657,10 @@ class TungstenAggregationIterator(
}
aggregationMode = newAggregationMode
+ allAggregateFunctions = initializeAllAggregateFunctions(startingInputBufferOffset = 0)
+
// Basically the value of the KVIterator returned by externalSorter
- // will just aggregation buffer. At here, we use cloneBufferAttributes.
+ // will just be the aggregation buffer. Here, we use inputAggBufferAttributes.
val newInputAttributes: Seq[Attribute] =
allAggregateFunctions.flatMap(_.inputAggBufferAttributes)
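
The central trick in the rewritten iterator is the NoOp interleaving: one projection updates the declarative aggregates' buffer slots while leaving the imperative aggregates' slots untouched, and those slots are then written by direct update() calls. A self-contained sketch of the interleaving, using a hypothetical miniature expression model:

// Hypothetical miniature model: a buffer slot is either driven by an
// expression (declarative) or skipped (NoOp) and handled imperatively.
sealed trait SlotUpdate
case class FromExpression(f: Long => Long) extends SlotUpdate
case object NoOp extends SlotUpdate

def applyProjection(buffer: Array[Long], updates: Seq[SlotUpdate]): Unit =
  updates.zipWithIndex.foreach {
    case (FromExpression(f), i) => buffer(i) = f(buffer(i)) // declarative slot
    case (NoOp, _)              => // imperative slot: left for update() calls
  }

val buffer = Array(0L, 0L)
applyProjection(buffer, Seq(FromExpression(_ + 1), NoOp)) // slot 0 updated
buffer(1) += 10L // imperative aggregate writes its own slot directly
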
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index fd02be1225..d2f56e0fc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -321,9 +321,17 @@ private[sql] class InputAggregationBuffer private[sql] (
*/
private[sql] case class ScalaUDAF(
children: Seq[Expression],
- udaf: UserDefinedAggregateFunction)
+ udaf: UserDefinedAggregateFunction,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
extends ImperativeAggregate with Logging {
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
require(
children.length == udaf.inputSchema.length,
s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
@@ -341,6 +349,11 @@ private[sql] case class ScalaUDAF(
override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
+ // Note: although this simply copies aggBufferAttributes, this common code cannot be placed
+ // in the superclass because that will lead to initialization ordering issues.
+ override val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
+
private[this] lazy val childrenSchema: StructType = {
val inputFields = children.zipWithIndex.map {
case (child, index) =>
@@ -382,51 +395,33 @@ private[sql] case class ScalaUDAF(
}
// This buffer is only used at executor side.
- private[this] var inputAggregateBuffer: InputAggregationBuffer = null
-
- // This buffer is only used at executor side.
- private[this] var mutableAggregateBuffer: MutableAggregationBufferImpl = null
+ private[this] lazy val inputAggregateBuffer: InputAggregationBuffer = {
+ new InputAggregationBuffer(
+ aggBufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ inputAggBufferOffset,
+ null)
+ }
// This buffer is only used at executor side.
- private[this] var evalAggregateBuffer: InputAggregationBuffer = null
-
- /**
- * Sets the inputBufferOffset to newInputBufferOffset and then create a new instance of
- * `inputAggregateBuffer` based on this new inputBufferOffset.
- */
- override def withNewInputAggBufferOffset(newInputBufferOffset: Int): Unit = {
- super.withNewInputAggBufferOffset(newInputBufferOffset)
- // inputBufferOffset has been updated.
- inputAggregateBuffer =
- new InputAggregationBuffer(
- aggBufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- inputAggBufferOffset,
- null)
+ private[this] lazy val mutableAggregateBuffer: MutableAggregationBufferImpl = {
+ new MutableAggregationBufferImpl(
+ aggBufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ mutableAggBufferOffset,
+ null)
}
- /**
- * Sets the mutableBufferOffset to newMutableBufferOffset and then create a new instance of
- * `mutableAggregateBuffer` and `evalAggregateBuffer` based on this new mutableBufferOffset.
- */
- override def withNewMutableAggBufferOffset(newMutableBufferOffset: Int): Unit = {
- super.withNewMutableAggBufferOffset(newMutableBufferOffset)
- // mutableBufferOffset has been updated.
- mutableAggregateBuffer =
- new MutableAggregationBufferImpl(
- aggBufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- mutableAggBufferOffset,
- null)
- evalAggregateBuffer =
- new InputAggregationBuffer(
- aggBufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- mutableAggBufferOffset,
- null)
+ // This buffer is only used at executor side.
+ private[this] lazy val evalAggregateBuffer: InputAggregationBuffer = {
+ new InputAggregationBuffer(
+ aggBufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ mutableAggBufferOffset,
+ null)
}
override def initialize(buffer: MutableRow): Unit = {
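
With the offsets now immutable constructor parameters, ScalaUDAF's three buffer wrappers become lazy vals derived from them instead of being rebuilt inside mutating setters. A stripped-down sketch of that refactoring, with a hypothetical BufferView in place of the real buffer classes:

// Hypothetical stand-in for InputAggregationBuffer / MutableAggregationBufferImpl.
case class BufferView(offset: Int)

case class Udaf(mutableOffset: Int = 0, inputOffset: Int = 0) {
  // Each lazy val is derived from an immutable offset, so a copy made via
  // withNewMutableOffset gets fresh, correctly positioned buffers on first use.
  private lazy val mutableBuffer = BufferView(mutableOffset)
  private lazy val evalBuffer = BufferView(mutableOffset)
  private lazy val inputBuffer = BufferView(inputOffset)

  def withNewMutableOffset(n: Int): Udaf = copy(mutableOffset = n)
  def withNewInputOffset(n: Int): Udaf = copy(inputOffset = n)
}
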
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index cf6e7ed0d3..eaafd83158 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -19,21 +19,12 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.execution.SparkPlan
/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
*/
object Utils {
- def supportsTungstenAggregate(
- groupingExpressions: Seq[Expression],
- aggregateBufferAttributes: Seq[Attribute]): Boolean = {
- val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
-
- UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
- UnsafeProjection.canSupport(groupingExpressions)
- }
def planAggregateWithoutPartial(
groupingExpressions: Seq[NamedExpression],
@@ -70,8 +61,7 @@ object Utils {
// Check if we can use TungstenAggregate.
val usesTungstenAggregate =
child.sqlContext.conf.unsafeEnabled &&
- aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[DeclarativeAggregate]) &&
- supportsTungstenAggregate(
+ TungstenAggregate.supportsAggregate(
groupingExpressions,
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
@@ -94,6 +84,7 @@ object Utils {
nonCompleteAggregateAttributes = partialAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
+ initialInputBufferOffset = 0,
resultExpressions = partialResultExpressions,
child = child)
} else {
@@ -125,6 +116,7 @@ object Utils {
nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
+ initialInputBufferOffset = groupingExpressions.length,
resultExpressions = resultExpressions,
child = partialAggregate)
} else {
@@ -154,143 +146,150 @@ object Utils {
val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct
val usesTungstenAggregate =
child.sqlContext.conf.unsafeEnabled &&
- aggregateExpressions.forall(
- _.aggregateFunction.isInstanceOf[DeclarativeAggregate]) &&
- supportsTungstenAggregate(
+ TungstenAggregate.supportsAggregate(
groupingExpressions,
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
- // 1. Create an Aggregate Operator for partial aggregations.
- val groupingAttributes = groupingExpressions.map(_.toAttribute)
-
- // It is safe to call head at here since functionsWithDistinct has at least one
- // AggregateExpression2.
- val distinctColumnExpressions =
- functionsWithDistinct.head.aggregateFunction.children
- val namedDistinctColumnExpressions = distinctColumnExpressions.map {
- case ne: NamedExpression => ne -> ne
- case other =>
- val withAlias = Alias(other, other.toString)()
- other -> withAlias
+ // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
+ // DISTINCT aggregate function, all of those functions will have the same column expression.
+ // For example, it would be valid for functionsWithDistinct to be
+ // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is
+ // disallowed because those two distinct aggregates have different column expressions.
+ val distinctColumnExpression: Expression = {
+ val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
+ assert(allDistinctColumnExpressions.length == 1)
+ allDistinctColumnExpressions.head
+ }
+ val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match {
+ case ne: NamedExpression => ne
+ case other => Alias(other, other.toString)()
}
- val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
- val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
+ val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
- val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
- val partialAggregateAttributes =
- partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
- val partialAggregateGroupingExpressions =
- groupingExpressions ++ namedDistinctColumnExpressions.map(_._2)
- val partialAggregateResult =
+ // 1. Create an Aggregate Operator for partial aggregations.
+ val partialAggregate: SparkPlan = {
+ val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+ val partialAggregateAttributes =
+ partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+ // We will group by the original grouping expression, plus an additional expression for the
+ // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
+ // expressions will be [key, value].
+ val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression
+ val partialAggregateResult =
groupingAttributes ++
- distinctColumnAttributes ++
- partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
- val partialAggregate = if (usesTungstenAggregate) {
- TungstenAggregate(
- requiredChildDistributionExpressions = None,
- // The grouping expressions are original groupingExpressions and
- // distinct columns. For example, for avg(distinct value) ... group by key
- // the grouping expressions of this Aggregate Operator will be [key, value].
- groupingExpressions = partialAggregateGroupingExpressions,
- nonCompleteAggregateExpressions = partialAggregateExpressions,
- nonCompleteAggregateAttributes = partialAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- resultExpressions = partialAggregateResult,
- child = child)
- } else {
- SortBasedAggregate(
- requiredChildDistributionExpressions = None,
- groupingExpressions = partialAggregateGroupingExpressions,
- nonCompleteAggregateExpressions = partialAggregateExpressions,
- nonCompleteAggregateAttributes = partialAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- initialInputBufferOffset = 0,
- resultExpressions = partialAggregateResult,
- child = child)
+ Seq(distinctColumnAttribute) ++
+ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+ if (usesTungstenAggregate) {
+ TungstenAggregate(
+ requiredChildDistributionExpressions = None,
+ groupingExpressions = partialAggregateGroupingExpressions,
+ nonCompleteAggregateExpressions = partialAggregateExpressions,
+ nonCompleteAggregateAttributes = partialAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = 0,
+ resultExpressions = partialAggregateResult,
+ child = child)
+ } else {
+ SortBasedAggregate(
+ requiredChildDistributionExpressions = None,
+ groupingExpressions = partialAggregateGroupingExpressions,
+ nonCompleteAggregateExpressions = partialAggregateExpressions,
+ nonCompleteAggregateAttributes = partialAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = 0,
+ resultExpressions = partialAggregateResult,
+ child = child)
+ }
}
// 2. Create an Aggregate Operator for partial merge aggregations.
- val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
- val partialMergeAggregateAttributes =
- partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
- val partialMergeAggregateResult =
+ val partialMergeAggregate: SparkPlan = {
+ val partialMergeAggregateExpressions =
+ functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+ val partialMergeAggregateAttributes =
+ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+ val partialMergeAggregateResult =
groupingAttributes ++
- distinctColumnAttributes ++
- partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
- val partialMergeAggregate = if (usesTungstenAggregate) {
- TungstenAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
- nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
- nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- resultExpressions = partialMergeAggregateResult,
- child = partialAggregate)
- } else {
- SortBasedAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
- nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
- nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
- completeAggregateExpressions = Nil,
- completeAggregateAttributes = Nil,
- initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
- resultExpressions = partialMergeAggregateResult,
- child = partialAggregate)
+ Seq(distinctColumnAttribute) ++
+ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+ if (usesTungstenAggregate) {
+ TungstenAggregate(
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+ nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+ nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ resultExpressions = partialMergeAggregateResult,
+ child = partialAggregate)
+ } else {
+ SortBasedAggregate(
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+ nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+ nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ resultExpressions = partialMergeAggregateResult,
+ child = partialAggregate)
+ }
}
- // 3. Create an Aggregate Operator for partial merge aggregations.
- val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
- // The attributes of the final aggregation buffer, which is presented as input to the result
- // projection:
- val finalAggregateAttributes = finalAggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
+ // 3. Create an Aggregate Operator for the final aggregation.
+ val finalAndCompleteAggregate: SparkPlan = {
+ val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
+ // The attributes of the final aggregation buffer, which is presented as input to the result
+ // projection:
+ val finalAggregateAttributes = finalAggregateExpressions.map {
+ expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
+ }
- val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
- // Children of an AggregateFunction with DISTINCT keyword has already
- // been evaluated. At here, we need to replace original children
- // to AttributeReferences.
- case agg @ AggregateExpression2(aggregateFunction, mode, true) =>
- val rewrittenAggregateFunction = aggregateFunction.transformDown {
- case expr if distinctColumnExpressionMap.contains(expr) =>
- distinctColumnExpressionMap(expr).toAttribute
- }.asInstanceOf[AggregateFunction2]
- // We rewrite the aggregate function to a non-distinct aggregation because
- // its input will have distinct arguments.
- // We just keep the isDistinct setting to true, so when users look at the query plan,
- // they still can see distinct aggregations.
- val rewrittenAggregateExpression =
- AggregateExpression2(rewrittenAggregateFunction, Complete, true)
+ val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
+ // The children of an AggregateFunction with the DISTINCT keyword have
+ // already been evaluated. Here, we need to replace the original children
+ // with AttributeReferences.
+ case agg @ AggregateExpression2(aggregateFunction, mode, true) =>
+ val rewrittenAggregateFunction = aggregateFunction.transformDown {
+ case expr if expr == distinctColumnExpression => distinctColumnAttribute
+ }.asInstanceOf[AggregateFunction2]
+ // We rewrite the aggregate function to a non-distinct aggregation because
+ // its input will have distinct arguments.
+ // We just keep the isDistinct setting to true, so when users look at the query plan,
+ // they still can see distinct aggregations.
+ val rewrittenAggregateExpression =
+ AggregateExpression2(rewrittenAggregateFunction, Complete, isDistinct = true)
- val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true)
- (rewrittenAggregateExpression, aggregateFunctionAttribute)
- }.unzip
-
- val finalAndCompleteAggregate = if (usesTungstenAggregate) {
- TungstenAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes,
- nonCompleteAggregateExpressions = finalAggregateExpressions,
- nonCompleteAggregateAttributes = finalAggregateAttributes,
- completeAggregateExpressions = completeAggregateExpressions,
- completeAggregateAttributes = completeAggregateAttributes,
- resultExpressions = resultExpressions,
- child = partialMergeAggregate)
- } else {
- SortBasedAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes,
- nonCompleteAggregateExpressions = finalAggregateExpressions,
- nonCompleteAggregateAttributes = finalAggregateAttributes,
- completeAggregateExpressions = completeAggregateExpressions,
- completeAggregateAttributes = completeAggregateAttributes,
- initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
- resultExpressions = resultExpressions,
- child = partialMergeAggregate)
+ val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true)
+ (rewrittenAggregateExpression, aggregateFunctionAttribute)
+ }.unzip
+ if (usesTungstenAggregate) {
+ TungstenAggregate(
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
+ nonCompleteAggregateExpressions = finalAggregateExpressions,
+ nonCompleteAggregateAttributes = finalAggregateAttributes,
+ completeAggregateExpressions = completeAggregateExpressions,
+ completeAggregateAttributes = completeAggregateAttributes,
+ initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ resultExpressions = resultExpressions,
+ child = partialMergeAggregate)
+ } else {
+ SortBasedAggregate(
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
+ nonCompleteAggregateExpressions = finalAggregateExpressions,
+ nonCompleteAggregateAttributes = finalAggregateAttributes,
+ completeAggregateExpressions = completeAggregateExpressions,
+ completeAggregateAttributes = completeAggregateAttributes,
+ initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ resultExpressions = resultExpressions,
+ child = partialMergeAggregate)
+ }
}
finalAndCompleteAggregate :: Nil
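
The three operators built above implement distinct aggregation by first grouping on (grouping columns, distinct column) to de-duplicate, then finishing the distinct aggregate as a Complete, non-distinct aggregate over the de-duplicated rows. A tiny runnable simulation of that strategy for COUNT(DISTINCT value) GROUP BY key, in plain Scala with no Spark APIs:

val rows = Seq(("a", 1), ("a", 1), ("a", 2), ("b", 3))

// Steps 1-2: group by (key, value), which de-duplicates the distinct column.
val deduped = rows.distinct

// Step 3: the rewritten, now non-distinct aggregate counts per key.
val counts = deduped.groupBy(_._1).map { case (key, group) => key -> group.size }

// counts == Map("a" -> 2, "b" -> 1)
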
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index ed974b3a53..0cc4988ff6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -39,7 +39,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte
}
val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty,
- Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
+ 0, Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
val numPages = iter.getHashMap.getNumDataPages
assert(numPages === 1)
} finally {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 18bbdb9908..a2ebf6552f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -553,10 +553,16 @@ private[hive] case class HiveGenericUDTF(
private[hive] case class HiveUDAFFunction(
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression],
- isUDAFBridgeRequired: Boolean = false)
+ isUDAFBridgeRequired: Boolean = false,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
extends ImperativeAggregate with HiveInspectors {
- def this() = this(null, null)
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
@transient
private lazy val resolver =
@@ -614,7 +620,11 @@ private[hive] case class HiveUDAFFunction(
buffer = function.getNewAggregationBuffer
}
- override def aggBufferAttributes: Seq[AttributeReference] = Nil
+ override val aggBufferAttributes: Seq[AttributeReference] = Nil
+
+ // Note: although this simply copies aggBufferAttributes, this common code cannot be placed
+ // in the superclass because that will lead to initialization ordering issues.
+ override val inputAggBufferAttributes: Seq[AttributeReference] = Nil
// We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
// catalyst type checking framework.