From a9ecd06149df4ccafd3927c35f63b9f03f170ae5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 7 Oct 2015 13:19:49 -0700 Subject: [SPARK-10941] [SQL] Refactor AggregateFunction2 and AlgebraicAggregate interfaces to improve code clarity This patch refactors several of the Aggregate2 interfaces in order to improve code clarity. The biggest change is a refactoring of the `AggregateFunction2` class hierarchy. In the old code, we had a class named `AlgebraicAggregate` that inherited from `AggregateFunction2`, added a new set of methods, then banned the use of the inherited methods. I found this to be fairly confusing because. If you look carefully at the existing code, you'll see that subclasses of `AggregateFunction2` fall into two disjoint categories: imperative aggregation functions which directly extended `AggregateFunction2` and declarative, expression-based aggregate functions which extended `AlgebraicAggregate`. In order to make this more explicit, this patch refactors things so that `AggregateFunction2` is a sealed abstract class with two subclasses, `ImperativeAggregateFunction` and `ExpressionAggregateFunction`. The superclass, `AggregateFunction2`, now only contains methods and fields that are common to both subclasses. After making this change, I updated the various AggregationIterator classes to comply with this new naming scheme. I also performed several small renamings in the aggregate interfaces themselves in order to improve clarity and rewrote or expanded a number of comments. Author: Josh Rosen Closes #8973 from JoshRosen/tungsten-agg-comments. --- .../catalyst/expressions/aggregate/functions.scala | 70 +++--- .../expressions/aggregate/interfaces.scala | 235 ++++++++++++++------- .../aggregate/HyperLogLogPlusPlusSuite.scala | 2 +- .../execution/aggregate/AggregationIterator.scala | 188 +++++++++-------- .../aggregate/SortBasedAggregationIterator.scala | 2 +- .../execution/aggregate/TungstenAggregate.scala | 2 - .../aggregate/TungstenAggregationIterator.scala | 80 +++---- .../spark/sql/execution/aggregate/udaf.scala | 36 ++-- .../spark/sql/execution/aggregate/utils.scala | 30 ++- .../TungstenAggregationIteratorSuite.scala | 2 +- .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 8 +- .../sql/hive/execution/AggregationQuerySuite.scala | 4 +- 12 files changed, 356 insertions(+), 303 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index e7f8104d43..4ad2607a85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -case class Average(child: Expression) extends AlgebraicAggregate { +case class Average(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -57,7 +57,7 @@ case class Average(child: Expression) extends AlgebraicAggregate { private val currentSum = AttributeReference("currentSum", sumDataType)() private val currentCount = AttributeReference("currentCount", LongType)() - override val bufferAttributes = currentSum :: currentCount :: Nil + override val aggBufferAttributes = currentSum :: currentCount :: Nil override val initialValues = Seq( /* currentSum = */ Cast(Literal(0), sumDataType), @@ -88,7 +88,7 @@ case class Average(child: Expression) extends AlgebraicAggregate { } } -case class Count(child: Expression) extends AlgebraicAggregate { +case class Count(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil override def nullable: Boolean = false @@ -101,7 +101,7 @@ case class Count(child: Expression) extends AlgebraicAggregate { private val currentCount = AttributeReference("currentCount", LongType)() - override val bufferAttributes = currentCount :: Nil + override val aggBufferAttributes = currentCount :: Nil override val initialValues = Seq( /* currentCount = */ Literal(0L) @@ -118,7 +118,7 @@ case class Count(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = Cast(currentCount, LongType) } -case class First(child: Expression) extends AlgebraicAggregate { +case class First(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -135,7 +135,7 @@ case class First(child: Expression) extends AlgebraicAggregate { private val first = AttributeReference("first", child.dataType)() - override val bufferAttributes = first :: Nil + override val aggBufferAttributes = first :: Nil override val initialValues = Seq( /* first = */ Literal.create(null, child.dataType) @@ -152,7 +152,7 @@ case class First(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = first } -case class Last(child: Expression) extends AlgebraicAggregate { +case class Last(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -169,7 +169,7 @@ case class Last(child: Expression) extends AlgebraicAggregate { private val last = AttributeReference("last", child.dataType)() - override val bufferAttributes = last :: Nil + override val aggBufferAttributes = last :: Nil override val initialValues = Seq( /* last = */ Literal.create(null, child.dataType) @@ -186,7 +186,7 @@ case class Last(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = last } -case class Max(child: Expression) extends AlgebraicAggregate { +case class Max(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -200,7 +200,7 @@ case class Max(child: Expression) extends AlgebraicAggregate { private val max = AttributeReference("max", child.dataType)() - override val bufferAttributes = max :: Nil + override val aggBufferAttributes = max :: Nil override val initialValues = Seq( /* max = */ Literal.create(null, child.dataType) @@ -220,7 +220,7 @@ case class Max(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = max } -case class Min(child: Expression) extends AlgebraicAggregate { +case class Min(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -234,7 +234,7 @@ case class Min(child: Expression) extends AlgebraicAggregate { private val min = AttributeReference("min", child.dataType)() - override val bufferAttributes = min :: Nil + override val aggBufferAttributes = min :: Nil override val initialValues = Seq( /* min = */ Literal.create(null, child.dataType) @@ -277,7 +277,7 @@ case class StddevSamp(child: Expression) extends StddevAgg(child) { // Compute standard deviation based on online algorithm specified here: // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -abstract class StddevAgg(child: Expression) extends AlgebraicAggregate { +abstract class StddevAgg(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -304,7 +304,7 @@ abstract class StddevAgg(child: Expression) extends AlgebraicAggregate { private val currentAvg = AttributeReference("currentAvg", resultType)() private val currentMk = AttributeReference("currentMk", resultType)() - override val bufferAttributes = preCount :: currentCount :: preAvg :: + override val aggBufferAttributes = preCount :: currentCount :: preAvg :: currentAvg :: currentMk :: Nil override val initialValues = Seq( @@ -397,7 +397,7 @@ abstract class StddevAgg(child: Expression) extends AlgebraicAggregate { } } -case class Sum(child: Expression) extends AlgebraicAggregate { +case class Sum(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -429,7 +429,7 @@ case class Sum(child: Expression) extends AlgebraicAggregate { private val zero = Cast(Literal(0), sumDataType) - override val bufferAttributes = currentSum :: Nil + override val aggBufferAttributes = currentSum :: Nil override val initialValues = Seq( /* currentSum = */ Literal.create(null, sumDataType) @@ -473,7 +473,7 @@ case class Sum(child: Expression) extends AlgebraicAggregate { */ // scalastyle:on case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) - extends AggregateFunction2 { + extends ImperativeAggregate { import HyperLogLogPlusPlus._ /** @@ -531,28 +531,26 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) */ private[this] val numWords = m / REGISTERS_PER_WORD + 1 - def children: Seq[Expression] = Seq(child) + override def children: Seq[Expression] = Seq(child) - def nullable: Boolean = false - - def dataType: DataType = LongType + override def nullable: Boolean = false - def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def dataType: DataType = LongType - def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes) + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - def cloneBufferAttributes: Seq[Attribute] = bufferAttributes.map(_.newInstance()) + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) /** Allocate enough words to store all registers. */ - val bufferAttributes: Seq[AttributeReference] = Seq.tabulate(numWords) { i => + override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(numWords) { i => AttributeReference(s"MS[$i]", LongType)() } /** Fill all words with zeros. */ - def initialize(buffer: MutableRow): Unit = { + override def initialize(buffer: MutableRow): Unit = { var word = 0 while (word < numWords) { - buffer.setLong(mutableBufferOffset + word, 0) + buffer.setLong(mutableAggBufferOffset + word, 0) word += 1 } } @@ -562,7 +560,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) * * Variable names in the HLL++ paper match variable names in the code. */ - def update(buffer: MutableRow, input: InternalRow): Unit = { + override def update(buffer: MutableRow, input: InternalRow): Unit = { val v = child.eval(input) if (v != null) { // Create the hashed value 'x'. @@ -576,7 +574,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) // Get the word containing the register we are interested in. val wordOffset = idx / REGISTERS_PER_WORD - val word = buffer.getLong(mutableBufferOffset + wordOffset) + val word = buffer.getLong(mutableAggBufferOffset + wordOffset) // Extract the M[J] register value from the word. val shift = REGISTER_SIZE * (idx - (wordOffset * REGISTERS_PER_WORD)) @@ -585,7 +583,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) // Assign the maximum number of leading zeros to the register. if (pw > Midx) { - buffer.setLong(mutableBufferOffset + wordOffset, (word & ~mask) | (pw << shift)) + buffer.setLong(mutableAggBufferOffset + wordOffset, (word & ~mask) | (pw << shift)) } } } @@ -594,12 +592,12 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) * Merge the HLL buffers by iterating through the registers in both buffers and select the * maximum number of leading zeros for each register. */ - def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { var idx = 0 var wordOffset = 0 while (wordOffset < numWords) { - val word1 = buffer1.getLong(mutableBufferOffset + wordOffset) - val word2 = buffer2.getLong(inputBufferOffset + wordOffset) + val word1 = buffer1.getLong(mutableAggBufferOffset + wordOffset) + val word2 = buffer2.getLong(inputAggBufferOffset + wordOffset) var word = 0L var i = 0 var mask = REGISTER_WORD_MASK @@ -609,7 +607,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) i += 1 idx += 1 } - buffer1.setLong(mutableBufferOffset + wordOffset, word) + buffer1.setLong(mutableAggBufferOffset + wordOffset, word) wordOffset += 1 } } @@ -664,14 +662,14 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) * * Variable names in the HLL++ paper match variable names in the code. */ - def eval(buffer: InternalRow): Any = { + override def eval(buffer: InternalRow): Any = { // Compute the inverse of indicator value 'z' and count the number of zeros 'V'. var zInverse = 0.0d var V = 0.0d var idx = 0 var wordOffset = 0 while (wordOffset < numWords) { - val word = buffer.getLong(mutableBufferOffset + wordOffset) + val word = buffer.getLong(mutableAggBufferOffset + wordOffset) var i = 0 var shift = 0 while (idx < m && i < REGISTERS_PER_WORD) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index d8699533cd..74e15ec90b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -69,9 +69,6 @@ private[sql] case object NoOp extends Expression with Unevaluable { /** * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. - * @param aggregateFunction - * @param mode - * @param isDistinct */ private[sql] case class AggregateExpression2( aggregateFunction: AggregateFunction2, @@ -86,7 +83,7 @@ private[sql] case class AggregateExpression2( override def references: AttributeSet = { val childReferences = mode match { case Partial | Complete => aggregateFunction.references.toSeq - case PartialMerge | Final => aggregateFunction.bufferAttributes + case PartialMerge | Final => aggregateFunction.aggBufferAttributes } AttributeSet(childReferences) @@ -95,98 +92,192 @@ private[sql] case class AggregateExpression2( override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" } -abstract class AggregateFunction2 - extends Expression with ImplicitCastInputTypes { +/** + * AggregateFunction2 is the superclass of two aggregation function interfaces: + * + * - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of + * initialize(), update(), and merge() functions that operate on Row-based aggregation buffers. + * - [[ExpressionAggregate]] is for aggregation functions that are specified using + * Catalyst expressions. + * + * In both interfaces, aggregates must define the schema ([[aggBufferSchema]]) and attributes + * ([[aggBufferAttributes]]) of an aggregation buffer which is used to hold partial aggregate + * results. At runtime, multiple aggregate functions are evaluated by the same operator using a + * combined aggregation buffer which concatenates the aggregation buffers of the individual + * aggregate functions. + * + * Code which accepts [[AggregateFunction2]] instances should be prepared to handle both types of + * aggregate functions. + */ +sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes { /** An aggregate function is not foldable. */ final override def foldable: Boolean = false + /** The schema of the aggregation buffer. */ + def aggBufferSchema: StructType + + /** Attributes of fields in aggBufferSchema. */ + def aggBufferAttributes: Seq[AttributeReference] + /** - * The offset of this function's start buffer value in the - * underlying shared mutable aggregation buffer. - * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share - * the same aggregation buffer. In this shared buffer, the position of the first - * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)` - * will be 2. + * Attributes of fields in input aggregation buffers (immutable aggregation buffers that are + * merged with mutable aggregation buffers in the merge() function or merge expressions). + * These attributes are created automatically by cloning the [[aggBufferAttributes]]. */ - protected var mutableBufferOffset: Int = 0 - - def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = { - mutableBufferOffset = newMutableBufferOffset - } + def inputAggBufferAttributes: Seq[AttributeReference] /** - * The offset of this function's start buffer value in the - * underlying shared input aggregation buffer. An input aggregation buffer is used - * when we merge two aggregation buffers and it is basically the immutable one - * (we merge an input aggregation buffer and a mutable aggregation buffer and - * then store the new buffer values to the mutable aggregation buffer). - * Usually, an input aggregation buffer also contain extra elements like grouping - * keys at the beginning. So, mutableBufferOffset and inputBufferOffset are often - * different. - * For example, we have a grouping expression `key``, and two aggregate functions - * `avg(x)` and `avg(y)`. In this shared input aggregation buffer, the position of the first - * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` - * will be 3 (position 0 is used for the value of key`). + * Indicates if this function supports partial aggregation. + * Currently Hive UDAF is the only one that doesn't support partial aggregation. */ - protected var inputBufferOffset: Int = 0 + def supportsPartial: Boolean = true - def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = { - inputBufferOffset = newInputBufferOffset - } + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") +} - /** The schema of the aggregation buffer. */ - def bufferSchema: StructType +/** + * API for aggregation functions that are expressed in terms of imperative initialize(), update(), + * and merge() functions which operate on Row-based aggregation buffers. + * + * Within these functions, code should access fields of the mutable aggregation buffer by adding the + * bufferSchema-relative field number to `mutableAggBufferOffset` then using this new field number + * to access the buffer Row. This is necessary because this aggregation function's buffer is + * embedded inside of a larger shared aggregation buffer when an aggregation operator evaluates + * multiple aggregate functions at the same time. + * + * We need to perform similar field number arithmetic when merging multiple intermediate + * aggregate buffers together in `merge()` (in this case, use `inputAggBufferOffset` when accessing + * the input buffer). + */ +abstract class ImperativeAggregate extends AggregateFunction2 { - /** Attributes of fields in bufferSchema. */ - def bufferAttributes: Seq[AttributeReference] + /** + * The offset of this function's first buffer value in the underlying shared mutable aggregation + * buffer. + * + * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share the same + * aggregation buffer. In this shared buffer, the position of the first buffer value of `avg(x)` + * will be 0 and the position of the first buffer value of `avg(y)` will be 2: + * + * avg(x) mutableAggBufferOffset = 0 + * | + * v + * +--------+--------+--------+--------+ + * | sum1 | count1 | sum2 | count2 | + * +--------+--------+--------+--------+ + * ^ + * | + * avg(y) mutableAggBufferOffset = 2 + * + */ + protected var mutableAggBufferOffset: Int = 0 - /** Clones bufferAttributes. */ - def cloneBufferAttributes: Seq[Attribute] + def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Unit = { + mutableAggBufferOffset = newMutableAggBufferOffset + } /** - * Initializes its aggregation buffer located in `buffer`. - * It will use bufferOffset to find the starting point of - * its buffer in the given `buffer` shared with other functions. + * The offset of this function's start buffer value in the underlying shared input aggregation + * buffer. An input aggregation buffer is used when we merge two aggregation buffers together in + * the `update()` function and is immutable (we merge an input aggregation buffer and a mutable + * aggregation buffer and then store the new buffer values to the mutable aggregation buffer). + * + * An input aggregation buffer may contain extra fields, such as grouping keys, at its start, so + * mutableAggBufferOffset and inputAggBufferOffset are often different. + * + * For example, say we have a grouping expression, `key`, and two aggregate functions, + * `avg(x)` and `avg(y)`. In the shared input aggregation buffer, the position of the first + * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` + * will be 3 (position 0 is used for the value of `key`): + * + * avg(x) inputAggBufferOffset = 1 + * | + * v + * +--------+--------+--------+--------+--------+ + * | key | sum1 | count1 | sum2 | count2 | + * +--------+--------+--------+--------+--------+ + * ^ + * | + * avg(y) inputAggBufferOffset = 3 + * */ - def initialize(buffer: MutableRow): Unit + protected var inputAggBufferOffset: Int = 0 + + def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Unit = { + inputAggBufferOffset = newInputAggBufferOffset + } /** - * Updates its aggregation buffer located in `buffer` based on the given `input`. - * It will use bufferOffset to find the starting point of its buffer in the given `buffer` - * shared with other functions. + * Initializes the mutable aggregation buffer located in `mutableAggBuffer`. + * + * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. */ - def update(buffer: MutableRow, input: InternalRow): Unit + def initialize(mutableAggBuffer: MutableRow): Unit /** - * Updates its aggregation buffer located in `buffer1` by combining intermediate results - * in the current buffer and intermediate results from another buffer `buffer2`. - * It will use bufferOffset to find the starting point of its buffer in the given `buffer1` - * and `buffer2`. + * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`. + * + * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. */ - def merge(buffer1: MutableRow, buffer2: InternalRow): Unit - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = - throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit /** - * Indicates if this function supports partial aggregation. - * Currently Hive UDAF is the only one that doesn't support partial aggregation. + * Combines new intermediate results from the `inputAggBuffer` with the existing intermediate + * results in the `mutableAggBuffer.` + * + * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. + * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. */ - def supportsPartial: Boolean = true + def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit + + final lazy val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) } /** - * A helper class for aggregate functions that can be implemented in terms of catalyst expressions. + * API for aggregation functions that are expressed in terms of Catalyst expressions. + * + * When implementing a new expression-based aggregate function, start by implementing + * `bufferAttributes`, defining attributes for the fields of the mutable aggregation buffer. You + * can then use these attributes when defining `updateExpressions`, `mergeExpressions`, and + * `evaluateExpressions`. */ -abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable with Unevaluable { +abstract class ExpressionAggregate + extends AggregateFunction2 + with Serializable + with Unevaluable { + /** + * Expressions for initializing empty aggregation buffers. + */ val initialValues: Seq[Expression] + + /** + * Expressions for updating the mutable aggregation buffer based on an input row. + */ val updateExpressions: Seq[Expression] + + /** + * A sequence of expressions for merging two aggregation buffers together. When defining these + * expressions, you can use the syntax `attributeName.left` and `attributeName.right` to refer + * to the attributes corresponding to each of the buffers being merged (this magic is enabled + * by the [[RichAttribute]] implicit class). + */ val mergeExpressions: Seq[Expression] + + /** + * An expression which returns the final value for this aggregate function. Its data type should + * match this expression's [[dataType]]. + */ val evaluateExpression: Expression - override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) + /** An expression-based aggregate's bufferSchema is derived from bufferAttributes. */ + final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + final lazy val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) /** * A helper class for representing an attribute used in merging two @@ -194,33 +285,13 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w * we merge buffer values and then update bufferLeft. A [[RichAttribute]] * of an [[AttributeReference]] `a` has two functions `left` and `right`, * which represent `a` in `bufferLeft` and `bufferRight`, respectively. - * @param a */ implicit class RichAttribute(a: AttributeReference) { /** Represents this attribute at the mutable buffer side. */ def left: AttributeReference = a /** Represents this attribute at the input buffer side (the data value is read-only). */ - def right: AttributeReference = cloneBufferAttributes(bufferAttributes.indexOf(a)) - } - - /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */ - override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes) - - override def initialize(buffer: MutableRow): Unit = { - throw new UnsupportedOperationException( - "AlgebraicAggregate's initialize should not be called directly") - } - - override final def update(buffer: MutableRow, input: InternalRow): Unit = { - throw new UnsupportedOperationException( - "AlgebraicAggregate's update should not be called directly") + def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a)) } - - override final def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - throw new UnsupportedOperationException( - "AlgebraicAggregate's merge should not be called directly") - } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala index ecc0644164..0d32949775 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala @@ -38,7 +38,7 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite { } def createBuffer(hll: HyperLogLogPlusPlus): MutableRow = { - val buffer = new SpecificMutableRow(hll.bufferAttributes.map(_.dataType)) + val buffer = new SpecificMutableRow(hll.aggBufferAttributes.map(_.dataType)) hll.initialize(buffer) buffer } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 62dbc07e88..04903022e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -85,9 +85,9 @@ abstract class AggregationIterator( while (i < allAggregateExpressions.length) { val func = allAggregateExpressions(i).aggregateFunction val funcWithBoundReferences = allAggregateExpressions(i).mode match { - case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] => + case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => // We need to create BoundReferences if the function is not an - // AlgebraicAggregate (it does not support code-gen) and the mode of + // expression-based aggregate function (it does not support code-gen) and the mode of // this function is Partial or Complete because we will call eval of this // function's children in the update method of this aggregate function. // Those eval calls require BoundReferences to work. @@ -95,31 +95,39 @@ abstract class AggregationIterator( case _ => // We only need to set inputBufferOffset for aggregate functions with mode // PartialMerge and Final. - func.withNewInputBufferOffset(inputBufferOffset) - inputBufferOffset += func.bufferSchema.length + func match { + case function: ImperativeAggregate => + function.withNewInputAggBufferOffset(inputBufferOffset) + case _ => + } + inputBufferOffset += func.aggBufferSchema.length func } // Set mutableBufferOffset for this function. It is important that setting // mutableBufferOffset happens after all potential bindReference operations // because bindReference will create a new instance of the function. - funcWithBoundReferences.withNewMutableBufferOffset(mutableBufferOffset) - mutableBufferOffset += funcWithBoundReferences.bufferSchema.length + funcWithBoundReferences match { + case function: ImperativeAggregate => + function.withNewMutableAggBufferOffset(mutableBufferOffset) + case _ => + } + mutableBufferOffset += funcWithBoundReferences.aggBufferSchema.length functions(i) = funcWithBoundReferences i += 1 } functions } - // Positions of those non-algebraic aggregate functions in allAggregateFunctions. + // Positions of those imperative aggregate functions in allAggregateFunctions. // For example, we have func1, func2, func3, func4 in aggregateFunctions, and - // func2 and func3 are non-algebraic aggregate functions. - // nonAlgebraicAggregateFunctionPositions will be [1, 2]. - private[this] val allNonAlgebraicAggregateFunctionPositions: Array[Int] = { + // func2 and func3 are imperative aggregate functions. + // ImperativeAggregateFunctionPositions will be [1, 2]. + private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { val positions = new ArrayBuffer[Int]() var i = 0 while (i < allAggregateFunctions.length) { allAggregateFunctions(i) match { - case agg: AlgebraicAggregate => + case agg: ExpressionAggregate => case _ => positions += i } i += 1 @@ -131,24 +139,26 @@ abstract class AggregationIterator( private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction2] = allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - // All non-algebraic aggregate functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = - nonCompleteAggregateFunctions.collect { - case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - } + // All imperative aggregate functions with mode Partial, PartialMerge, or Final. + private[this] val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = + nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } - // The projection used to initialize buffer values for all AlgebraicAggregates. - private[this] val algebraicInitialProjection = { + // The projection used to initialize buffer values for all expression-based aggregates. + private[this] val expressionAggInitialProjection = { val initExpressions = allAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.initialValues - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + case ae: ExpressionAggregate => ae.initialValues + // For the positions corresponding to imperative aggregate functions, we'll use special + // no-op expressions which are ignored during projection code-generation. + case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) } newMutableProjection(initExpressions, Nil)() } - // All non-Algebraic AggregateFunctions. - private[this] val allNonAlgebraicAggregateFunctions = - allNonAlgebraicAggregateFunctionPositions.map(allAggregateFunctions) + // All imperative AggregateFunctions. + private[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] = + allImperativeAggregateFunctionPositions + .map(allAggregateFunctions) + .map(_.asInstanceOf[ImperativeAggregate]) /////////////////////////////////////////////////////////////////////////// // Methods and fields used by sub-classes. @@ -157,25 +167,25 @@ abstract class AggregationIterator( // Initializing functions used to process a row. protected val processRow: (MutableRow, InternalRow) => Unit = { val rowToBeProcessed = new JoinedRow - val aggregationBufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + val aggregationBufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) aggregationMode match { // Partial-only case (Some(Partial), None) => val updateExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + case ae: ExpressionAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } - val algebraicUpdateProjection = + val expressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() (currentBuffer: MutableRow, row: InternalRow) => { - algebraicUpdateProjection.target(currentBuffer) - // Process all algebraic aggregate functions. - algebraicUpdateProjection(rowToBeProcessed(currentBuffer, row)) - // Process all non-algebraic aggregate functions. + expressionAggUpdateProjection.target(currentBuffer) + // Process all expression-based aggregate functions. + expressionAggUpdateProjection(rowToBeProcessed(currentBuffer, row)) + // Process all imperative aggregate functions. var i = 0 - while (i < nonCompleteNonAlgebraicAggregateFunctions.length) { - nonCompleteNonAlgebraicAggregateFunctions(i).update(currentBuffer, row) + while (i < nonCompleteImperativeAggregateFunctions.length) { + nonCompleteImperativeAggregateFunctions(i).update(currentBuffer, row) i += 1 } } @@ -186,30 +196,30 @@ abstract class AggregationIterator( // If initialInputBufferOffset, the input value does not contain // grouping keys. // This part is pretty hacky. - allAggregateFunctions.flatMap(_.cloneBufferAttributes).toSeq + allAggregateFunctions.flatMap(_.inputAggBufferAttributes).toSeq } else { - groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.cloneBufferAttributes) + groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.inputAggBufferAttributes) } // val inputAggregationBufferSchema = // groupingKeyAttributes ++ // allAggregateFunctions.flatMap(_.cloneBufferAttributes) val mergeExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + case ae: ExpressionAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } - // This projection is used to merge buffer values for all AlgebraicAggregates. - val algebraicMergeProjection = + // This projection is used to merge buffer values for all expression-based aggregates. + val expressionAggMergeProjection = newMutableProjection( mergeExpressions, aggregationBufferSchema ++ inputAggregationBufferSchema)() (currentBuffer: MutableRow, row: InternalRow) => { - // Process all algebraic aggregate functions. - algebraicMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row)) - // Process all non-algebraic aggregate functions. + // Process all expression-based aggregate functions. + expressionAggMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row)) + // Process all imperative aggregate functions. var i = 0 - while (i < nonCompleteNonAlgebraicAggregateFunctions.length) { - nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row) + while (i < nonCompleteImperativeAggregateFunctions.length) { + nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) i += 1 } } @@ -218,57 +228,55 @@ abstract class AggregationIterator( case (Some(Final), Some(Complete)) => val completeAggregateFunctions: Array[AggregateFunction2] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All non-algebraic aggregate functions with mode Complete. - val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = - completeAggregateFunctions.collect { - case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - } + // All imperative aggregate functions with mode Complete. + val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = + completeAggregateFunctions.collect { case func: ImperativeAggregate => func } // The first initialInputBufferOffset values of the input aggregation buffer is // for grouping expressions and distinct columns. val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset) val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) // We do not touch buffer values of aggregate functions with the Final mode. val finalOffsetExpressions = - Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) val mergeInputSchema = aggregationBufferSchema ++ groupingAttributesAndDistinctColumns ++ - nonCompleteAggregateFunctions.flatMap(_.cloneBufferAttributes) + nonCompleteAggregateFunctions.flatMap(_.inputAggBufferAttributes) val mergeExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + case ae: ExpressionAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } ++ completeOffsetExpressions - val finalAlgebraicMergeProjection = + val finalExpressionAggMergeProjection = newMutableProjection(mergeExpressions, mergeInputSchema)() val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + case ae: ExpressionAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } - val completeAlgebraicUpdateProjection = + val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() (currentBuffer: MutableRow, row: InternalRow) => { val input = rowToBeProcessed(currentBuffer, row) // For all aggregate functions with mode Complete, update buffers. - completeAlgebraicUpdateProjection.target(currentBuffer)(input) + completeExpressionAggUpdateProjection.target(currentBuffer)(input) var i = 0 - while (i < completeNonAlgebraicAggregateFunctions.length) { - completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row) + while (i < completeImperativeAggregateFunctions.length) { + completeImperativeAggregateFunctions(i).update(currentBuffer, row) i += 1 } // For all aggregate functions with mode Final, merge buffers. - finalAlgebraicMergeProjection.target(currentBuffer)(input) + finalExpressionAggMergeProjection.target(currentBuffer)(input) i = 0 - while (i < nonCompleteNonAlgebraicAggregateFunctions.length) { - nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row) + while (i < nonCompleteImperativeAggregateFunctions.length) { + nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) i += 1 } } @@ -277,27 +285,25 @@ abstract class AggregationIterator( case (None, Some(Complete)) => val completeAggregateFunctions: Array[AggregateFunction2] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All non-algebraic aggregate functions with mode Complete. - val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = - completeAggregateFunctions.collect { - case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - } + // All imperative aggregate functions with mode Complete. + val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = + completeAggregateFunctions.collect { case func: ImperativeAggregate => func } val updateExpressions = completeAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + case ae: ExpressionAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } - val completeAlgebraicUpdateProjection = + val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() (currentBuffer: MutableRow, row: InternalRow) => { val input = rowToBeProcessed(currentBuffer, row) // For all aggregate functions with mode Complete, update buffers. - completeAlgebraicUpdateProjection.target(currentBuffer)(input) + completeExpressionAggUpdateProjection.target(currentBuffer)(input) var i = 0 - while (i < completeNonAlgebraicAggregateFunctions.length) { - completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row) + while (i < completeImperativeAggregateFunctions.length) { + completeImperativeAggregateFunctions(i).update(currentBuffer, row) i += 1 } } @@ -315,11 +321,11 @@ abstract class AggregationIterator( // Initializing the function used to generate the output row. protected val generateOutput: (InternalRow, MutableRow) => InternalRow = { val rowToBeEvaluated = new JoinedRow - val safeOutoutRow = new GenericMutableRow(resultExpressions.length) + val safeOutputRow = new GenericMutableRow(resultExpressions.length) val mutableOutput = if (outputsUnsafeRows) { - UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutoutRow) + UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow) } else { - safeOutoutRow + safeOutputRow } aggregationMode match { @@ -329,7 +335,7 @@ abstract class AggregationIterator( // Because we cannot copy a joinedRow containing a UnsafeRow (UnsafeRow does not // support generic getter), we create a mutable projection to output the // JoinedRow(currentGroupingKey, currentBuffer) - val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.bufferAttributes) + val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.aggBufferAttributes) val resultProjection = newMutableProjection( groupingKeyAttributes ++ bufferSchema, @@ -345,12 +351,12 @@ abstract class AggregationIterator( // resultExpressions. case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => val bufferSchemata = - allAggregateFunctions.flatMap(_.bufferAttributes) + allAggregateFunctions.flatMap(_.aggBufferAttributes) val evalExpressions = allAggregateFunctions.map { - case ae: AlgebraicAggregate => ae.evaluateExpression + case ae: ExpressionAggregate => ae.evaluateExpression case agg: AggregateFunction2 => NoOp } - val algebraicEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() + val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes // TODO: Use unsafe row. val aggregateResult = new GenericMutableRow(aggregateResultSchema.length) @@ -360,14 +366,14 @@ abstract class AggregationIterator( resultProjection.target(mutableOutput) (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - // Generate results for all algebraic aggregate functions. - algebraicEvalProjection.target(aggregateResult)(currentBuffer) - // Generate results for all non-algebraic aggregate functions. + // Generate results for all expression-based aggregate functions. + expressionAggEvalProjection.target(aggregateResult)(currentBuffer) + // Generate results for all imperative aggregate functions. var i = 0 - while (i < allNonAlgebraicAggregateFunctions.length) { + while (i < allImperativeAggregateFunctions.length) { aggregateResult.update( - allNonAlgebraicAggregateFunctionPositions(i), - allNonAlgebraicAggregateFunctions(i).eval(currentBuffer)) + allImperativeAggregateFunctionPositions(i), + allImperativeAggregateFunctions(i).eval(currentBuffer)) i += 1 } resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult)) @@ -392,10 +398,10 @@ abstract class AggregationIterator( /** Initializes buffer values for all aggregate functions. */ protected def initializeBuffer(buffer: MutableRow): Unit = { - algebraicInitialProjection.target(buffer)(EmptyRow) + expressionAggInitialProjection.target(buffer)(EmptyRow) var i = 0 - while (i < allNonAlgebraicAggregateFunctions.length) { - allNonAlgebraicAggregateFunctions(i).initialize(buffer) + while (i < allImperativeAggregateFunctions.length) { + allImperativeAggregateFunctions(i).initialize(buffer) i += 1 } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 73d50e07cf..a9e5d175bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -54,7 +54,7 @@ class SortBasedAggregationIterator( outputsUnsafeRows) { override protected def newBuffer: MutableRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index ba379d358d..3cd22af305 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -32,7 +32,6 @@ case class TungstenAggregate( groupingExpressions: Seq[NamedExpression], nonCompleteAggregateExpressions: Seq[AggregateExpression2], completeAggregateExpressions: Seq[AggregateExpression2], - initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { @@ -79,7 +78,6 @@ case class TungstenAggregate( groupingExpressions, nonCompleteAggregateExpressions, completeAggregateExpressions, - initialInputBufferOffset, resultExpressions, newMutableProjection, child.output, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 26fdbc83ef..6a84c0af0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -62,10 +62,6 @@ import org.apache.spark.sql.types.StructType * [[PartialMerge]], or [[Final]]. * @param completeAggregateExpressions * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]]. - * @param initialInputBufferOffset - * If this iterator is used to handle functions with mode [[PartialMerge]] or [[Final]]. - * The input rows have the format of `grouping keys + aggregation buffer`. - * This offset indicates the starting position of aggregation buffer in a input row. * @param resultExpressions * expressions for generating output rows. * @param newMutableProjection @@ -77,7 +73,6 @@ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], nonCompleteAggregateExpressions: Seq[AggregateExpression2], completeAggregateExpressions: Seq[AggregateExpression2], - initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], @@ -133,17 +128,18 @@ class TungstenAggregationIterator( completeAggregateExpressions.map(_.mode).distinct.headOption } - // All aggregate functions. TungstenAggregationIterator only handles AlgebraicAggregates. - // If there is any functions that is not an AlgebraicAggregate, we throw an + // All aggregate functions. TungstenAggregationIterator only handles expression-based aggregate. + // If there is any functions that is an ImperativeAggregateFunction, we throw an // IllegalStateException. - private[this] val allAggregateFunctions: Array[AlgebraicAggregate] = { - if (!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])) { + private[this] val allAggregateFunctions: Array[ExpressionAggregate] = { + if (!allAggregateExpressions.forall( + _.aggregateFunction.isInstanceOf[ExpressionAggregate])) { throw new IllegalStateException( - "Only AlgebraicAggregates should be passed in TungstenAggregationIterator.") + "Only ExpressionAggregateFunctions should be passed in TungstenAggregationIterator.") } allAggregateExpressions - .map(_.aggregateFunction.asInstanceOf[AlgebraicAggregate]) + .map(_.aggregateFunction.asInstanceOf[ExpressionAggregate]) .toArray } @@ -154,7 +150,7 @@ class TungstenAggregationIterator( /////////////////////////////////////////////////////////////////////////// // The projection used to initialize buffer values. - private[this] val algebraicInitialProjection: MutableProjection = { + private[this] val initialProjection: MutableProjection = { val initExpressions = allAggregateFunctions.flatMap(_.initialValues) newMutableProjection(initExpressions, Nil)() } @@ -164,14 +160,14 @@ class TungstenAggregationIterator( // when we switch to sort-based aggregation, and when we create the re-used buffer for // sort-based aggregation). private def createNewAggregationBuffer(): UnsafeRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) val unsafeProjection = UnsafeProjection.create(bufferSchema.map(_.dataType)) val buffer = unsafeProjection.apply(genericMutableBuffer) - algebraicInitialProjection.target(buffer)(EmptyRow) + initialProjection.target(buffer)(EmptyRow) buffer } @@ -179,84 +175,78 @@ class TungstenAggregationIterator( private def generateProcessRow( inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = { - val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) + val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes) val joinedRow = new JoinedRow() aggregationMode match { // Partial-only case (Some(Partial), None) => val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions) - val algebraicUpdateProjection = + val updateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { - algebraicUpdateProjection.target(currentBuffer) - algebraicUpdateProjection(joinedRow(currentBuffer, row)) + updateProjection.target(currentBuffer) + updateProjection(joinedRow(currentBuffer, row)) } // PartialMerge-only or Final-only case (Some(PartialMerge), None) | (Some(Final), None) => val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions) - // This projection is used to merge buffer values for all AlgebraicAggregates. - val algebraicMergeProjection = - newMutableProjection( - mergeExpressions, - aggregationBufferAttributes ++ inputAttributes)() + val mergeProjection = + newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { - // Process all algebraic aggregate functions. - algebraicMergeProjection.target(currentBuffer) - algebraicMergeProjection(joinedRow(currentBuffer, row)) + mergeProjection.target(currentBuffer) + mergeProjection(joinedRow(currentBuffer, row)) } // Final-Complete case (Some(Final), Some(Complete)) => - val nonCompleteAggregateFunctions: Array[AlgebraicAggregate] = + val nonCompleteAggregateFunctions: Array[ExpressionAggregate] = allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - val completeAggregateFunctions: Array[AlgebraicAggregate] = + val completeAggregateFunctions: Array[ExpressionAggregate] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) val mergeExpressions = nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions - val finalAlgebraicMergeProjection = - newMutableProjection( - mergeExpressions, - aggregationBufferAttributes ++ inputAttributes)() + val finalMergeProjection = + newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() // We do not touch buffer values of aggregate functions with the Final mode. val finalOffsetExpressions = - Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions) - val completeAlgebraicUpdateProjection = + val completeUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { val input = joinedRow(currentBuffer, row) // For all aggregate functions with mode Complete, update the given currentBuffer. - completeAlgebraicUpdateProjection.target(currentBuffer)(input) + completeUpdateProjection.target(currentBuffer)(input) // For all aggregate functions with mode Final, merge buffer values in row to // currentBuffer. - finalAlgebraicMergeProjection.target(currentBuffer)(input) + finalMergeProjection.target(currentBuffer)(input) } // Complete-only case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AlgebraicAggregate] = + val completeAggregateFunctions: Array[ExpressionAggregate] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) val updateExpressions = completeAggregateFunctions.flatMap(_.updateExpressions) - val completeAlgebraicUpdateProjection = + val completeUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { - completeAlgebraicUpdateProjection.target(currentBuffer) + completeUpdateProjection.target(currentBuffer) // For all aggregate functions with mode Complete, update the given currentBuffer. - completeAlgebraicUpdateProjection(joinedRow(currentBuffer, row)) + completeUpdateProjection(joinedRow(currentBuffer, row)) } // Grouping only. @@ -272,7 +262,7 @@ class TungstenAggregationIterator( private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = { val groupingAttributes = groupingExpressions.map(_.toAttribute) - val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes) aggregationMode match { // Partial-only or PartialMerge-only: every output row is basically the values of @@ -339,7 +329,7 @@ class TungstenAggregationIterator( // all groups and their corresponding aggregation buffers for hash-based aggregation. private[this] val hashMap = new UnsafeFixedWidthAggregationMap( initialAggregationBuffer, - StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)), + StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), TaskContext.get.taskMemoryManager(), SparkEnv.get.shuffleMemoryManager, @@ -480,7 +470,7 @@ class TungstenAggregationIterator( // The originalInputAttributes are using cloneBufferAttributes. So, we need to use // allAggregateFunctions.flatMap(_.cloneBufferAttributes). val bufferExtractor = newMutableProjection( - allAggregateFunctions.flatMap(_.cloneBufferAttributes), + allAggregateFunctions.flatMap(_.inputAggBufferAttributes), originalInputAttributes)() bufferExtractor.target(buffer) @@ -510,7 +500,7 @@ class TungstenAggregationIterator( // Basically the value of the KVIterator returned by externalSorter // will just aggregation buffer. At here, we use cloneBufferAttributes. val newInputAttributes: Seq[Attribute] = - allAggregateFunctions.flatMap(_.cloneBufferAttributes) + allAggregateFunctions.flatMap(_.inputAggBufferAttributes) // Set up new processRow and generateOutput. processRow = generateProcessRow(newInputAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 1114fe6552..fd02be1225 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2 +import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction2} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ @@ -318,13 +318,11 @@ private[sql] class InputAggregationBuffer private[sql] ( /** * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the * internal aggregation code path. - * @param children - * @param udaf */ private[sql] case class ScalaUDAF( children: Seq[Expression], udaf: UserDefinedAggregateFunction) - extends AggregateFunction2 with Logging { + extends ImperativeAggregate with Logging { require( children.length == udaf.inputSchema.length, @@ -339,11 +337,9 @@ private[sql] case class ScalaUDAF( override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType) - override val bufferSchema: StructType = udaf.bufferSchema + override val aggBufferSchema: StructType = udaf.bufferSchema - override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes - - override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) + override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes private[this] lazy val childrenSchema: StructType = { val inputFields = children.zipWithIndex.map { @@ -370,13 +366,13 @@ private[sql] case class ScalaUDAF( CatalystTypeConverters.createToScalaConverter(childrenSchema) private[this] lazy val bufferValuesToCatalystConverters: Array[Any => Any] = { - bufferSchema.fields.map { field => + aggBufferSchema.fields.map { field => CatalystTypeConverters.createToCatalystConverter(field.dataType) } } private[this] lazy val bufferValuesToScalaConverters: Array[Any => Any] = { - bufferSchema.fields.map { field => + aggBufferSchema.fields.map { field => CatalystTypeConverters.createToScalaConverter(field.dataType) } } @@ -398,15 +394,15 @@ private[sql] case class ScalaUDAF( * Sets the inputBufferOffset to newInputBufferOffset and then create a new instance of * `inputAggregateBuffer` based on this new inputBufferOffset. */ - override def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = { - super.withNewInputBufferOffset(newInputBufferOffset) + override def withNewInputAggBufferOffset(newInputBufferOffset: Int): Unit = { + super.withNewInputAggBufferOffset(newInputBufferOffset) // inputBufferOffset has been updated. inputAggregateBuffer = new InputAggregationBuffer( - bufferSchema, + aggBufferSchema, bufferValuesToCatalystConverters, bufferValuesToScalaConverters, - inputBufferOffset, + inputAggBufferOffset, null) } @@ -414,22 +410,22 @@ private[sql] case class ScalaUDAF( * Sets the mutableBufferOffset to newMutableBufferOffset and then create a new instance of * `mutableAggregateBuffer` and `evalAggregateBuffer` based on this new mutableBufferOffset. */ - override def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = { - super.withNewMutableBufferOffset(newMutableBufferOffset) + override def withNewMutableAggBufferOffset(newMutableBufferOffset: Int): Unit = { + super.withNewMutableAggBufferOffset(newMutableBufferOffset) // mutableBufferOffset has been updated. mutableAggregateBuffer = new MutableAggregationBufferImpl( - bufferSchema, + aggBufferSchema, bufferValuesToCatalystConverters, bufferValuesToScalaConverters, - mutableBufferOffset, + mutableAggBufferOffset, null) evalAggregateBuffer = new InputAggregationBuffer( - bufferSchema, + aggBufferSchema, bufferValuesToCatalystConverters, bufferValuesToScalaConverters, - mutableBufferOffset, + mutableAggBufferOffset, null) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 4f5e86cceb..e1d7e1bf02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -97,10 +97,10 @@ object Utils { // Check if we can use TungstenAggregate. val usesTungstenAggregate = child.sqlContext.conf.unsafeEnabled && - aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) && + aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[ExpressionAggregate]) && supportsTungstenAggregate( groupingExpressions, - aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // 1. Create an Aggregate Operator for partial aggregations. @@ -117,10 +117,10 @@ object Utils { val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) val partialAggregateAttributes = - partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) + partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialResultExpressions = namedGroupingAttributes ++ - partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) val partialAggregate = if (usesTungstenAggregate) { TungstenAggregate( @@ -128,7 +128,6 @@ object Utils { groupingExpressions = namedGroupingExpressions.map(_._2), nonCompleteAggregateExpressions = partialAggregateExpressions, completeAggregateExpressions = Nil, - initialInputBufferOffset = 0, resultExpressions = partialResultExpressions, child = child) } else { @@ -158,7 +157,7 @@ object Utils { // aggregateFunctionMap contains unique aggregate functions. val aggregateFunction = aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1 - aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression + aggregateFunction.asInstanceOf[ExpressionAggregate].evaluateExpression case expression => // We do not rely on the equality check at here since attributes may // different cosmetically. Instead, we use semanticEquals. @@ -173,7 +172,6 @@ object Utils { groupingExpressions = namedGroupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, completeAggregateExpressions = Nil, - initialInputBufferOffset = namedGroupingAttributes.length, resultExpressions = rewrittenResultExpressions, child = partialAggregate) } else { @@ -216,10 +214,11 @@ object Utils { val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct val usesTungstenAggregate = child.sqlContext.conf.unsafeEnabled && - aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) && + aggregateExpressions.forall( + _.aggregateFunction.isInstanceOf[ExpressionAggregate]) && supportsTungstenAggregate( groupingExpressions, - aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // 1. Create an Aggregate Operator for partial aggregations. // The grouping expressions are original groupingExpressions and @@ -252,20 +251,19 @@ object Utils { val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) val partialAggregateAttributes = - partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) + partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialAggregateGroupingExpressions = (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2) val partialAggregateResult = namedGroupingAttributes ++ distinctColumnAttributes ++ - partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) val partialAggregate = if (usesTungstenAggregate) { TungstenAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], groupingExpressions = partialAggregateGroupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, completeAggregateExpressions = Nil, - initialInputBufferOffset = 0, resultExpressions = partialAggregateResult, child = child) } else { @@ -284,18 +282,17 @@ object Utils { // 2. Create an Aggregate Operator for partial merge aggregations. val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialMergeAggregateResult = namedGroupingAttributes ++ distinctColumnAttributes ++ - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) val partialMergeAggregate = if (usesTungstenAggregate) { TungstenAggregate( requiredChildDistributionExpressions = Some(namedGroupingAttributes), groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, completeAggregateExpressions = Nil, - initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) } else { @@ -362,7 +359,7 @@ object Utils { // aggregate functions that have not been rewritten. aggregateFunctionMap(function, isDistinct)._1 } - aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression + aggregateFunction.asInstanceOf[ExpressionAggregate].evaluateExpression case expression => // We do not rely on the equality check at here since attributes may // different cosmetically. Instead, we use semanticEquals. @@ -377,7 +374,6 @@ object Utils { groupingExpressions = namedGroupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, completeAggregateExpressions = completeAggregateExpressions, - initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, resultExpressions = rewrittenResultExpressions, child = partialMergeAggregate) } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index afda0d29f6..7ca677a6c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -38,7 +38,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte () => new InterpretedMutableProjection(expr, schema) } val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy") - iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0, + iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) val numPages = iter.getHashMap.getNumDataPages assert(numPages === 1) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index a85d4db88d..18bbdb9908 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -554,7 +554,7 @@ private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, children: Seq[Expression], isUDAFBridgeRequired: Boolean = false) - extends AggregateFunction2 with HiveInspectors { + extends ImperativeAggregate with HiveInspectors { def this() = this(null, null) @@ -598,7 +598,7 @@ private[hive] case class HiveUDAFFunction( // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation // buffer for it. - override def bufferSchema: StructType = StructType(Nil) + override def aggBufferSchema: StructType = StructType(Nil) override def update(_buffer: MutableRow, input: InternalRow): Unit = { val inputs = inputProjection(input) @@ -610,13 +610,11 @@ private[hive] case class HiveUDAFFunction( "Hive UDAF doesn't support partial aggregate") } - override def cloneBufferAttributes: Seq[Attribute] = Nil - override def initialize(_buffer: MutableRow): Unit = { buffer = function.getNewAggregationBuffer } - override def bufferAttributes: Seq[AttributeReference] = Nil + override def aggBufferAttributes: Seq[AttributeReference] = Nil // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our // catalyst type checking framework. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 24b1846923..c9e1bb1995 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -343,7 +343,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, null, 110.0, null, null, 10.0) :: Nil) } - test("non-AlgebraicAggregate aggreguate function") { + test("interpreted aggregate function") { checkAnswer( sqlContext.sql( """ @@ -368,7 +368,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null) :: Nil) } - test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") { + test("interpreted and expression-based aggregation functions") { checkAnswer( sqlContext.sql( """ -- cgit v1.2.3