aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2015-10-07 13:19:49 -0700
committerYin Huai <yhuai@databricks.com>2015-10-07 13:19:49 -0700
commita9ecd06149df4ccafd3927c35f63b9f03f170ae5 (patch)
tree99d41c5b5f8cd786e2c161c339373a740b4048eb
parent5be5d247440d6346d667c4b3d817666126f62906 (diff)
downloadspark-a9ecd06149df4ccafd3927c35f63b9f03f170ae5.tar.gz
spark-a9ecd06149df4ccafd3927c35f63b9f03f170ae5.tar.bz2
spark-a9ecd06149df4ccafd3927c35f63b9f03f170ae5.zip
[SPARK-10941] [SQL] Refactor AggregateFunction2 and AlgebraicAggregate interfaces to improve code clarity
This patch refactors several of the Aggregate2 interfaces in order to improve code clarity. The biggest change is a refactoring of the `AggregateFunction2` class hierarchy. In the old code, we had a class named `AlgebraicAggregate` that inherited from `AggregateFunction2`, added a new set of methods, then banned the use of the inherited methods. I found this to be fairly confusing. If you look carefully at the existing code, you'll see that subclasses of `AggregateFunction2` fall into two disjoint categories: imperative aggregation functions which directly extended `AggregateFunction2` and declarative, expression-based aggregate functions which extended `AlgebraicAggregate`. In order to make this more explicit, this patch refactors things so that `AggregateFunction2` is a sealed abstract class with two subclasses, `ImperativeAggregate` and `ExpressionAggregate`. The superclass, `AggregateFunction2`, now only contains methods and fields that are common to both subclasses. After making this change, I updated the various AggregationIterator classes to comply with this new naming scheme. I also performed several small renamings in the aggregate interfaces themselves in order to improve clarity and rewrote or expanded a number of comments. Author: Josh Rosen <joshrosen@databricks.com> Closes #8973 from JoshRosen/tungsten-agg-comments.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala70
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala235
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala188
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala80
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala36
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala30
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala2
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala8
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala4
12 files changed, 356 insertions, 303 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index e7f8104d43..4ad2607a85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-case class Average(child: Expression) extends AlgebraicAggregate {
+case class Average(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -57,7 +57,7 @@ case class Average(child: Expression) extends AlgebraicAggregate {
private val currentSum = AttributeReference("currentSum", sumDataType)()
private val currentCount = AttributeReference("currentCount", LongType)()
- override val bufferAttributes = currentSum :: currentCount :: Nil
+ override val aggBufferAttributes = currentSum :: currentCount :: Nil
override val initialValues = Seq(
/* currentSum = */ Cast(Literal(0), sumDataType),
@@ -88,7 +88,7 @@ case class Average(child: Expression) extends AlgebraicAggregate {
}
}
-case class Count(child: Expression) extends AlgebraicAggregate {
+case class Count(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = false
@@ -101,7 +101,7 @@ case class Count(child: Expression) extends AlgebraicAggregate {
private val currentCount = AttributeReference("currentCount", LongType)()
- override val bufferAttributes = currentCount :: Nil
+ override val aggBufferAttributes = currentCount :: Nil
override val initialValues = Seq(
/* currentCount = */ Literal(0L)
@@ -118,7 +118,7 @@ case class Count(child: Expression) extends AlgebraicAggregate {
override val evaluateExpression = Cast(currentCount, LongType)
}
-case class First(child: Expression) extends AlgebraicAggregate {
+case class First(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -135,7 +135,7 @@ case class First(child: Expression) extends AlgebraicAggregate {
private val first = AttributeReference("first", child.dataType)()
- override val bufferAttributes = first :: Nil
+ override val aggBufferAttributes = first :: Nil
override val initialValues = Seq(
/* first = */ Literal.create(null, child.dataType)
@@ -152,7 +152,7 @@ case class First(child: Expression) extends AlgebraicAggregate {
override val evaluateExpression = first
}
-case class Last(child: Expression) extends AlgebraicAggregate {
+case class Last(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -169,7 +169,7 @@ case class Last(child: Expression) extends AlgebraicAggregate {
private val last = AttributeReference("last", child.dataType)()
- override val bufferAttributes = last :: Nil
+ override val aggBufferAttributes = last :: Nil
override val initialValues = Seq(
/* last = */ Literal.create(null, child.dataType)
@@ -186,7 +186,7 @@ case class Last(child: Expression) extends AlgebraicAggregate {
override val evaluateExpression = last
}
-case class Max(child: Expression) extends AlgebraicAggregate {
+case class Max(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -200,7 +200,7 @@ case class Max(child: Expression) extends AlgebraicAggregate {
private val max = AttributeReference("max", child.dataType)()
- override val bufferAttributes = max :: Nil
+ override val aggBufferAttributes = max :: Nil
override val initialValues = Seq(
/* max = */ Literal.create(null, child.dataType)
@@ -220,7 +220,7 @@ case class Max(child: Expression) extends AlgebraicAggregate {
override val evaluateExpression = max
}
-case class Min(child: Expression) extends AlgebraicAggregate {
+case class Min(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -234,7 +234,7 @@ case class Min(child: Expression) extends AlgebraicAggregate {
private val min = AttributeReference("min", child.dataType)()
- override val bufferAttributes = min :: Nil
+ override val aggBufferAttributes = min :: Nil
override val initialValues = Seq(
/* min = */ Literal.create(null, child.dataType)
@@ -277,7 +277,7 @@ case class StddevSamp(child: Expression) extends StddevAgg(child) {
// Compute standard deviation based on online algorithm specified here:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-abstract class StddevAgg(child: Expression) extends AlgebraicAggregate {
+abstract class StddevAgg(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -304,7 +304,7 @@ abstract class StddevAgg(child: Expression) extends AlgebraicAggregate {
private val currentAvg = AttributeReference("currentAvg", resultType)()
private val currentMk = AttributeReference("currentMk", resultType)()
- override val bufferAttributes = preCount :: currentCount :: preAvg ::
+ override val aggBufferAttributes = preCount :: currentCount :: preAvg ::
currentAvg :: currentMk :: Nil
override val initialValues = Seq(
@@ -397,7 +397,7 @@ abstract class StddevAgg(child: Expression) extends AlgebraicAggregate {
}
}
-case class Sum(child: Expression) extends AlgebraicAggregate {
+case class Sum(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -429,7 +429,7 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
private val zero = Cast(Literal(0), sumDataType)
- override val bufferAttributes = currentSum :: Nil
+ override val aggBufferAttributes = currentSum :: Nil
override val initialValues = Seq(
/* currentSum = */ Literal.create(null, sumDataType)
@@ -473,7 +473,7 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
*/
// scalastyle:on
case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
- extends AggregateFunction2 {
+ extends ImperativeAggregate {
import HyperLogLogPlusPlus._
/**
@@ -531,28 +531,26 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
*/
private[this] val numWords = m / REGISTERS_PER_WORD + 1
- def children: Seq[Expression] = Seq(child)
+ override def children: Seq[Expression] = Seq(child)
- def nullable: Boolean = false
-
- def dataType: DataType = LongType
+ override def nullable: Boolean = false
- def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+ override def dataType: DataType = LongType
- def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- def cloneBufferAttributes: Seq[Attribute] = bufferAttributes.map(_.newInstance())
+ override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
/** Allocate enough words to store all registers. */
- val bufferAttributes: Seq[AttributeReference] = Seq.tabulate(numWords) { i =>
+ override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(numWords) { i =>
AttributeReference(s"MS[$i]", LongType)()
}
/** Fill all words with zeros. */
- def initialize(buffer: MutableRow): Unit = {
+ override def initialize(buffer: MutableRow): Unit = {
var word = 0
while (word < numWords) {
- buffer.setLong(mutableBufferOffset + word, 0)
+ buffer.setLong(mutableAggBufferOffset + word, 0)
word += 1
}
}
@@ -562,7 +560,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
*
* Variable names in the HLL++ paper match variable names in the code.
*/
- def update(buffer: MutableRow, input: InternalRow): Unit = {
+ override def update(buffer: MutableRow, input: InternalRow): Unit = {
val v = child.eval(input)
if (v != null) {
// Create the hashed value 'x'.
@@ -576,7 +574,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
// Get the word containing the register we are interested in.
val wordOffset = idx / REGISTERS_PER_WORD
- val word = buffer.getLong(mutableBufferOffset + wordOffset)
+ val word = buffer.getLong(mutableAggBufferOffset + wordOffset)
// Extract the M[J] register value from the word.
val shift = REGISTER_SIZE * (idx - (wordOffset * REGISTERS_PER_WORD))
@@ -585,7 +583,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
// Assign the maximum number of leading zeros to the register.
if (pw > Midx) {
- buffer.setLong(mutableBufferOffset + wordOffset, (word & ~mask) | (pw << shift))
+ buffer.setLong(mutableAggBufferOffset + wordOffset, (word & ~mask) | (pw << shift))
}
}
}
@@ -594,12 +592,12 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
* Merge the HLL buffers by iterating through the registers in both buffers and select the
* maximum number of leading zeros for each register.
*/
- def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
var idx = 0
var wordOffset = 0
while (wordOffset < numWords) {
- val word1 = buffer1.getLong(mutableBufferOffset + wordOffset)
- val word2 = buffer2.getLong(inputBufferOffset + wordOffset)
+ val word1 = buffer1.getLong(mutableAggBufferOffset + wordOffset)
+ val word2 = buffer2.getLong(inputAggBufferOffset + wordOffset)
var word = 0L
var i = 0
var mask = REGISTER_WORD_MASK
@@ -609,7 +607,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
i += 1
idx += 1
}
- buffer1.setLong(mutableBufferOffset + wordOffset, word)
+ buffer1.setLong(mutableAggBufferOffset + wordOffset, word)
wordOffset += 1
}
}
@@ -664,14 +662,14 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
*
* Variable names in the HLL++ paper match variable names in the code.
*/
- def eval(buffer: InternalRow): Any = {
+ override def eval(buffer: InternalRow): Any = {
// Compute the inverse of indicator value 'z' and count the number of zeros 'V'.
var zInverse = 0.0d
var V = 0.0d
var idx = 0
var wordOffset = 0
while (wordOffset < numWords) {
- val word = buffer.getLong(mutableBufferOffset + wordOffset)
+ val word = buffer.getLong(mutableAggBufferOffset + wordOffset)
var i = 0
var shift = 0
while (idx < m && i < REGISTERS_PER_WORD) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index d8699533cd..74e15ec90b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -69,9 +69,6 @@ private[sql] case object NoOp extends Expression with Unevaluable {
/**
* A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
* (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
- * @param aggregateFunction
- * @param mode
- * @param isDistinct
*/
private[sql] case class AggregateExpression2(
aggregateFunction: AggregateFunction2,
@@ -86,7 +83,7 @@ private[sql] case class AggregateExpression2(
override def references: AttributeSet = {
val childReferences = mode match {
case Partial | Complete => aggregateFunction.references.toSeq
- case PartialMerge | Final => aggregateFunction.bufferAttributes
+ case PartialMerge | Final => aggregateFunction.aggBufferAttributes
}
AttributeSet(childReferences)
@@ -95,98 +92,192 @@ private[sql] case class AggregateExpression2(
override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)"
}
-abstract class AggregateFunction2
- extends Expression with ImplicitCastInputTypes {
+/**
+ * AggregateFunction2 is the superclass of two aggregation function interfaces:
+ *
+ * - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of
+ * initialize(), update(), and merge() functions that operate on Row-based aggregation buffers.
+ * - [[ExpressionAggregate]] is for aggregation functions that are specified using
+ * Catalyst expressions.
+ *
+ * In both interfaces, aggregates must define the schema ([[aggBufferSchema]]) and attributes
+ * ([[aggBufferAttributes]]) of an aggregation buffer which is used to hold partial aggregate
+ * results. At runtime, multiple aggregate functions are evaluated by the same operator using a
+ * combined aggregation buffer which concatenates the aggregation buffers of the individual
+ * aggregate functions.
+ *
+ * Code which accepts [[AggregateFunction2]] instances should be prepared to handle both types of
+ * aggregate functions.
+ */
+sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes {
/** An aggregate function is not foldable. */
final override def foldable: Boolean = false
+ /** The schema of the aggregation buffer. */
+ def aggBufferSchema: StructType
+
+ /** Attributes of fields in aggBufferSchema. */
+ def aggBufferAttributes: Seq[AttributeReference]
+
/**
- * The offset of this function's start buffer value in the
- * underlying shared mutable aggregation buffer.
- * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share
- * the same aggregation buffer. In this shared buffer, the position of the first
- * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)`
- * will be 2.
+ * Attributes of fields in input aggregation buffers (immutable aggregation buffers that are
+ * merged with mutable aggregation buffers in the merge() function or merge expressions).
+ * These attributes are created automatically by cloning the [[aggBufferAttributes]].
*/
- protected var mutableBufferOffset: Int = 0
-
- def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = {
- mutableBufferOffset = newMutableBufferOffset
- }
+ def inputAggBufferAttributes: Seq[AttributeReference]
/**
- * The offset of this function's start buffer value in the
- * underlying shared input aggregation buffer. An input aggregation buffer is used
- * when we merge two aggregation buffers and it is basically the immutable one
- * (we merge an input aggregation buffer and a mutable aggregation buffer and
- * then store the new buffer values to the mutable aggregation buffer).
- * Usually, an input aggregation buffer also contain extra elements like grouping
- * keys at the beginning. So, mutableBufferOffset and inputBufferOffset are often
- * different.
- * For example, we have a grouping expression `key``, and two aggregate functions
- * `avg(x)` and `avg(y)`. In this shared input aggregation buffer, the position of the first
- * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)`
- * will be 3 (position 0 is used for the value of key`).
+ * Indicates if this function supports partial aggregation.
+ * Currently Hive UDAF is the only one that doesn't support partial aggregation.
*/
- protected var inputBufferOffset: Int = 0
+ def supportsPartial: Boolean = true
- def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = {
- inputBufferOffset = newInputBufferOffset
- }
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+}
- /** The schema of the aggregation buffer. */
- def bufferSchema: StructType
+/**
+ * API for aggregation functions that are expressed in terms of imperative initialize(), update(),
+ * and merge() functions which operate on Row-based aggregation buffers.
+ *
+ * Within these functions, code should access fields of the mutable aggregation buffer by adding the
+ * aggBufferSchema-relative field number to `mutableAggBufferOffset` then using this new field number
+ * to access the buffer Row. This is necessary because this aggregation function's buffer is
+ * embedded inside of a larger shared aggregation buffer when an aggregation operator evaluates
+ * multiple aggregate functions at the same time.
+ *
+ * We need to perform similar field number arithmetic when merging multiple intermediate
+ * aggregate buffers together in `merge()` (in this case, use `inputAggBufferOffset` when accessing
+ * the input buffer).
+ */
+abstract class ImperativeAggregate extends AggregateFunction2 {
- /** Attributes of fields in bufferSchema. */
- def bufferAttributes: Seq[AttributeReference]
+ /**
+ * The offset of this function's first buffer value in the underlying shared mutable aggregation
+ * buffer.
+ *
+ * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share the same
+ * aggregation buffer. In this shared buffer, the position of the first buffer value of `avg(x)`
+ * will be 0 and the position of the first buffer value of `avg(y)` will be 2:
+ *
+ * avg(x) mutableAggBufferOffset = 0
+ * |
+ * v
+ * +--------+--------+--------+--------+
+ * | sum1 | count1 | sum2 | count2 |
+ * +--------+--------+--------+--------+
+ * ^
+ * |
+ * avg(y) mutableAggBufferOffset = 2
+ *
+ */
+ protected var mutableAggBufferOffset: Int = 0
- /** Clones bufferAttributes. */
- def cloneBufferAttributes: Seq[Attribute]
+ def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Unit = {
+ mutableAggBufferOffset = newMutableAggBufferOffset
+ }
/**
- * Initializes its aggregation buffer located in `buffer`.
- * It will use bufferOffset to find the starting point of
- * its buffer in the given `buffer` shared with other functions.
+ * The offset of this function's start buffer value in the underlying shared input aggregation
+ * buffer. An input aggregation buffer is used when we merge two aggregation buffers together in
+ * the `merge()` function and is immutable (we merge an input aggregation buffer and a mutable
+ * aggregation buffer and then store the new buffer values to the mutable aggregation buffer).
+ *
+ * An input aggregation buffer may contain extra fields, such as grouping keys, at its start, so
+ * mutableAggBufferOffset and inputAggBufferOffset are often different.
+ *
+ * For example, say we have a grouping expression, `key`, and two aggregate functions,
+ * `avg(x)` and `avg(y)`. In the shared input aggregation buffer, the position of the first
+ * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)`
+ * will be 3 (position 0 is used for the value of `key`):
+ *
+ * avg(x) inputAggBufferOffset = 1
+ * |
+ * v
+ * +--------+--------+--------+--------+--------+
+ * | key | sum1 | count1 | sum2 | count2 |
+ * +--------+--------+--------+--------+--------+
+ * ^
+ * |
+ * avg(y) inputAggBufferOffset = 3
+ *
*/
- def initialize(buffer: MutableRow): Unit
+ protected var inputAggBufferOffset: Int = 0
+
+ def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Unit = {
+ inputAggBufferOffset = newInputAggBufferOffset
+ }
/**
- * Updates its aggregation buffer located in `buffer` based on the given `input`.
- * It will use bufferOffset to find the starting point of its buffer in the given `buffer`
- * shared with other functions.
+ * Initializes the mutable aggregation buffer located in `mutableAggBuffer`.
+ *
+ * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
*/
- def update(buffer: MutableRow, input: InternalRow): Unit
+ def initialize(mutableAggBuffer: MutableRow): Unit
/**
- * Updates its aggregation buffer located in `buffer1` by combining intermediate results
- * in the current buffer and intermediate results from another buffer `buffer2`.
- * It will use bufferOffset to find the starting point of its buffer in the given `buffer1`
- * and `buffer2`.
+ * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`.
+ *
+ * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
*/
- def merge(buffer1: MutableRow, buffer2: InternalRow): Unit
-
- override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
- throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+ def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit
/**
- * Indicates if this function supports partial aggregation.
- * Currently Hive UDAF is the only one that doesn't support partial aggregation.
+ * Combines new intermediate results from the `inputAggBuffer` with the existing intermediate
+ * results in the `mutableAggBuffer`.
+ *
+ * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
+ * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`.
*/
- def supportsPartial: Boolean = true
+ def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit
+
+ final lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
}
/**
- * A helper class for aggregate functions that can be implemented in terms of catalyst expressions.
+ * API for aggregation functions that are expressed in terms of Catalyst expressions.
+ *
+ * When implementing a new expression-based aggregate function, start by implementing
+ * `aggBufferAttributes`, defining attributes for the fields of the mutable aggregation buffer. You
+ * can then use these attributes when defining `updateExpressions`, `mergeExpressions`, and
+ * `evaluateExpression`.
*/
-abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable with Unevaluable {
+abstract class ExpressionAggregate
+ extends AggregateFunction2
+ with Serializable
+ with Unevaluable {
+ /**
+ * Expressions for initializing empty aggregation buffers.
+ */
val initialValues: Seq[Expression]
+
+ /**
+ * Expressions for updating the mutable aggregation buffer based on an input row.
+ */
val updateExpressions: Seq[Expression]
+
+ /**
+ * A sequence of expressions for merging two aggregation buffers together. When defining these
+ * expressions, you can use the syntax `attributeName.left` and `attributeName.right` to refer
+ * to the attributes corresponding to each of the buffers being merged (this magic is enabled
+ * by the [[RichAttribute]] implicit class).
+ */
val mergeExpressions: Seq[Expression]
+
+ /**
+ * An expression which returns the final value for this aggregate function. Its data type should
+ * match this expression's [[dataType]].
+ */
val evaluateExpression: Expression
- override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+ /** An expression-based aggregate's aggBufferSchema is derived from aggBufferAttributes. */
+ final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+ final lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
/**
* A helper class for representing an attribute used in merging two
@@ -194,33 +285,13 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w
* we merge buffer values and then update bufferLeft. A [[RichAttribute]]
* of an [[AttributeReference]] `a` has two functions `left` and `right`,
* which represent `a` in `bufferLeft` and `bufferRight`, respectively.
- * @param a
*/
implicit class RichAttribute(a: AttributeReference) {
/** Represents this attribute at the mutable buffer side. */
def left: AttributeReference = a
/** Represents this attribute at the input buffer side (the data value is read-only). */
- def right: AttributeReference = cloneBufferAttributes(bufferAttributes.indexOf(a))
- }
-
- /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */
- override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)
-
- override def initialize(buffer: MutableRow): Unit = {
- throw new UnsupportedOperationException(
- "AlgebraicAggregate's initialize should not be called directly")
- }
-
- override final def update(buffer: MutableRow, input: InternalRow): Unit = {
- throw new UnsupportedOperationException(
- "AlgebraicAggregate's update should not be called directly")
+ def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
}
-
- override final def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- throw new UnsupportedOperationException(
- "AlgebraicAggregate's merge should not be called directly")
- }
-
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala
index ecc0644164..0d32949775 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala
@@ -38,7 +38,7 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite {
}
def createBuffer(hll: HyperLogLogPlusPlus): MutableRow = {
- val buffer = new SpecificMutableRow(hll.bufferAttributes.map(_.dataType))
+ val buffer = new SpecificMutableRow(hll.aggBufferAttributes.map(_.dataType))
hll.initialize(buffer)
buffer
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 62dbc07e88..04903022e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -85,9 +85,9 @@ abstract class AggregationIterator(
while (i < allAggregateExpressions.length) {
val func = allAggregateExpressions(i).aggregateFunction
val funcWithBoundReferences = allAggregateExpressions(i).mode match {
- case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
+ case Partial | Complete if func.isInstanceOf[ImperativeAggregate] =>
// We need to create BoundReferences if the function is not an
- // AlgebraicAggregate (it does not support code-gen) and the mode of
+ // expression-based aggregate function (it does not support code-gen) and the mode of
// this function is Partial or Complete because we will call eval of this
// function's children in the update method of this aggregate function.
// Those eval calls require BoundReferences to work.
@@ -95,31 +95,39 @@ abstract class AggregationIterator(
case _ =>
// We only need to set inputBufferOffset for aggregate functions with mode
// PartialMerge and Final.
- func.withNewInputBufferOffset(inputBufferOffset)
- inputBufferOffset += func.bufferSchema.length
+ func match {
+ case function: ImperativeAggregate =>
+ function.withNewInputAggBufferOffset(inputBufferOffset)
+ case _ =>
+ }
+ inputBufferOffset += func.aggBufferSchema.length
func
}
// Set mutableBufferOffset for this function. It is important that setting
// mutableBufferOffset happens after all potential bindReference operations
// because bindReference will create a new instance of the function.
- funcWithBoundReferences.withNewMutableBufferOffset(mutableBufferOffset)
- mutableBufferOffset += funcWithBoundReferences.bufferSchema.length
+ funcWithBoundReferences match {
+ case function: ImperativeAggregate =>
+ function.withNewMutableAggBufferOffset(mutableBufferOffset)
+ case _ =>
+ }
+ mutableBufferOffset += funcWithBoundReferences.aggBufferSchema.length
functions(i) = funcWithBoundReferences
i += 1
}
functions
}
- // Positions of those non-algebraic aggregate functions in allAggregateFunctions.
+ // Positions of those imperative aggregate functions in allAggregateFunctions.
// For example, we have func1, func2, func3, func4 in aggregateFunctions, and
- // func2 and func3 are non-algebraic aggregate functions.
- // nonAlgebraicAggregateFunctionPositions will be [1, 2].
- private[this] val allNonAlgebraicAggregateFunctionPositions: Array[Int] = {
+ // func2 and func3 are imperative aggregate functions.
+ // allImperativeAggregateFunctionPositions will be [1, 2].
+ private[this] val allImperativeAggregateFunctionPositions: Array[Int] = {
val positions = new ArrayBuffer[Int]()
var i = 0
while (i < allAggregateFunctions.length) {
allAggregateFunctions(i) match {
- case agg: AlgebraicAggregate =>
+ case agg: ExpressionAggregate =>
case _ => positions += i
}
i += 1
@@ -131,24 +139,26 @@ abstract class AggregationIterator(
private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction2] =
allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
- // All non-algebraic aggregate functions with mode Partial, PartialMerge, or Final.
- private[this] val nonCompleteNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
- nonCompleteAggregateFunctions.collect {
- case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
- }
+ // All imperative aggregate functions with mode Partial, PartialMerge, or Final.
+ private[this] val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] =
+ nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func }
- // The projection used to initialize buffer values for all AlgebraicAggregates.
- private[this] val algebraicInitialProjection = {
+ // The projection used to initialize buffer values for all expression-based aggregates.
+ private[this] val expressionAggInitialProjection = {
val initExpressions = allAggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.initialValues
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ case ae: ExpressionAggregate => ae.initialValues
+ // For the positions corresponding to imperative aggregate functions, we'll use special
+ // no-op expressions which are ignored during projection code-generation.
+ case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp)
}
newMutableProjection(initExpressions, Nil)()
}
- // All non-Algebraic AggregateFunctions.
- private[this] val allNonAlgebraicAggregateFunctions =
- allNonAlgebraicAggregateFunctionPositions.map(allAggregateFunctions)
+ // All imperative AggregateFunctions.
+ private[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] =
+ allImperativeAggregateFunctionPositions
+ .map(allAggregateFunctions)
+ .map(_.asInstanceOf[ImperativeAggregate])
///////////////////////////////////////////////////////////////////////////
// Methods and fields used by sub-classes.
@@ -157,25 +167,25 @@ abstract class AggregationIterator(
// Initializing functions used to process a row.
protected val processRow: (MutableRow, InternalRow) => Unit = {
val rowToBeProcessed = new JoinedRow
- val aggregationBufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+ val aggregationBufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
aggregationMode match {
// Partial-only
case (Some(Partial), None) =>
val updateExpressions = nonCompleteAggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.updateExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ case ae: ExpressionAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
- val algebraicUpdateProjection =
+ val expressionAggUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
(currentBuffer: MutableRow, row: InternalRow) => {
- algebraicUpdateProjection.target(currentBuffer)
- // Process all algebraic aggregate functions.
- algebraicUpdateProjection(rowToBeProcessed(currentBuffer, row))
- // Process all non-algebraic aggregate functions.
+ expressionAggUpdateProjection.target(currentBuffer)
+ // Process all expression-based aggregate functions.
+ expressionAggUpdateProjection(rowToBeProcessed(currentBuffer, row))
+ // Process all imperative aggregate functions.
var i = 0
- while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
- nonCompleteNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
+ while (i < nonCompleteImperativeAggregateFunctions.length) {
+ nonCompleteImperativeAggregateFunctions(i).update(currentBuffer, row)
i += 1
}
}
@@ -186,30 +196,30 @@ abstract class AggregationIterator(
// If initialInputBufferOffset, the input value does not contain
// grouping keys.
// This part is pretty hacky.
- allAggregateFunctions.flatMap(_.cloneBufferAttributes).toSeq
+ allAggregateFunctions.flatMap(_.inputAggBufferAttributes).toSeq
} else {
- groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.cloneBufferAttributes)
+ groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.inputAggBufferAttributes)
}
// val inputAggregationBufferSchema =
// groupingKeyAttributes ++
// allAggregateFunctions.flatMap(_.cloneBufferAttributes)
val mergeExpressions = nonCompleteAggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.mergeExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ case ae: ExpressionAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
- // This projection is used to merge buffer values for all AlgebraicAggregates.
- val algebraicMergeProjection =
+ // This projection is used to merge buffer values for all expression-based aggregates.
+ val expressionAggMergeProjection =
newMutableProjection(
mergeExpressions,
aggregationBufferSchema ++ inputAggregationBufferSchema)()
(currentBuffer: MutableRow, row: InternalRow) => {
- // Process all algebraic aggregate functions.
- algebraicMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row))
- // Process all non-algebraic aggregate functions.
+ // Process all expression-based aggregate functions.
+ expressionAggMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row))
+ // Process all imperative aggregate functions.
var i = 0
- while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
- nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row)
+ while (i < nonCompleteImperativeAggregateFunctions.length) {
+ nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row)
i += 1
}
}
@@ -218,57 +228,55 @@ abstract class AggregationIterator(
case (Some(Final), Some(Complete)) =>
val completeAggregateFunctions: Array[AggregateFunction2] =
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
- // All non-algebraic aggregate functions with mode Complete.
- val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
- completeAggregateFunctions.collect {
- case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
- }
+ // All imperative aggregate functions with mode Complete.
+ val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
+ completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
// The first initialInputBufferOffset values of the input aggregation buffer is
// for grouping expressions and distinct columns.
val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset)
val completeOffsetExpressions =
- Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+ Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
// We do not touch buffer values of aggregate functions with the Final mode.
val finalOffsetExpressions =
- Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+ Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
val mergeInputSchema =
aggregationBufferSchema ++
groupingAttributesAndDistinctColumns ++
- nonCompleteAggregateFunctions.flatMap(_.cloneBufferAttributes)
+ nonCompleteAggregateFunctions.flatMap(_.inputAggBufferAttributes)
val mergeExpressions =
nonCompleteAggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.mergeExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ case ae: ExpressionAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
} ++ completeOffsetExpressions
- val finalAlgebraicMergeProjection =
+ val finalExpressionAggMergeProjection =
newMutableProjection(mergeExpressions, mergeInputSchema)()
val updateExpressions =
finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.updateExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ case ae: ExpressionAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
- val completeAlgebraicUpdateProjection =
+ val completeExpressionAggUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
(currentBuffer: MutableRow, row: InternalRow) => {
val input = rowToBeProcessed(currentBuffer, row)
// For all aggregate functions with mode Complete, update buffers.
- completeAlgebraicUpdateProjection.target(currentBuffer)(input)
+ completeExpressionAggUpdateProjection.target(currentBuffer)(input)
var i = 0
- while (i < completeNonAlgebraicAggregateFunctions.length) {
- completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
+ while (i < completeImperativeAggregateFunctions.length) {
+ completeImperativeAggregateFunctions(i).update(currentBuffer, row)
i += 1
}
// For all aggregate functions with mode Final, merge buffers.
- finalAlgebraicMergeProjection.target(currentBuffer)(input)
+ finalExpressionAggMergeProjection.target(currentBuffer)(input)
i = 0
- while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
- nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row)
+ while (i < nonCompleteImperativeAggregateFunctions.length) {
+ nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row)
i += 1
}
}
@@ -277,27 +285,25 @@ abstract class AggregationIterator(
case (None, Some(Complete)) =>
val completeAggregateFunctions: Array[AggregateFunction2] =
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
- // All non-algebraic aggregate functions with mode Complete.
- val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
- completeAggregateFunctions.collect {
- case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
- }
+ // All imperative aggregate functions with mode Complete.
+ val completeImperativeAggregateFunctions: Array[ImperativeAggregate] =
+ completeAggregateFunctions.collect { case func: ImperativeAggregate => func }
val updateExpressions =
completeAggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.updateExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ case ae: ExpressionAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
- val completeAlgebraicUpdateProjection =
+ val completeExpressionAggUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
(currentBuffer: MutableRow, row: InternalRow) => {
val input = rowToBeProcessed(currentBuffer, row)
// For all aggregate functions with mode Complete, update buffers.
- completeAlgebraicUpdateProjection.target(currentBuffer)(input)
+ completeExpressionAggUpdateProjection.target(currentBuffer)(input)
var i = 0
- while (i < completeNonAlgebraicAggregateFunctions.length) {
- completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
+ while (i < completeImperativeAggregateFunctions.length) {
+ completeImperativeAggregateFunctions(i).update(currentBuffer, row)
i += 1
}
}
@@ -315,11 +321,11 @@ abstract class AggregationIterator(
// Initializing the function used to generate the output row.
protected val generateOutput: (InternalRow, MutableRow) => InternalRow = {
val rowToBeEvaluated = new JoinedRow
- val safeOutoutRow = new GenericMutableRow(resultExpressions.length)
+ val safeOutputRow = new GenericMutableRow(resultExpressions.length)
val mutableOutput = if (outputsUnsafeRows) {
- UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutoutRow)
+ UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow)
} else {
- safeOutoutRow
+ safeOutputRow
}
aggregationMode match {
@@ -329,7 +335,7 @@ abstract class AggregationIterator(
// Because we cannot copy a joinedRow containing a UnsafeRow (UnsafeRow does not
// support generic getter), we create a mutable projection to output the
// JoinedRow(currentGroupingKey, currentBuffer)
- val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.bufferAttributes)
+ val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.aggBufferAttributes)
val resultProjection =
newMutableProjection(
groupingKeyAttributes ++ bufferSchema,
@@ -345,12 +351,12 @@ abstract class AggregationIterator(
// resultExpressions.
case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
val bufferSchemata =
- allAggregateFunctions.flatMap(_.bufferAttributes)
+ allAggregateFunctions.flatMap(_.aggBufferAttributes)
val evalExpressions = allAggregateFunctions.map {
- case ae: AlgebraicAggregate => ae.evaluateExpression
+ case ae: ExpressionAggregate => ae.evaluateExpression
case agg: AggregateFunction2 => NoOp
}
- val algebraicEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)()
+ val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)()
val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
// TODO: Use unsafe row.
val aggregateResult = new GenericMutableRow(aggregateResultSchema.length)
@@ -360,14 +366,14 @@ abstract class AggregationIterator(
resultProjection.target(mutableOutput)
(currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
- // Generate results for all algebraic aggregate functions.
- algebraicEvalProjection.target(aggregateResult)(currentBuffer)
- // Generate results for all non-algebraic aggregate functions.
+ // Generate results for all expression-based aggregate functions.
+ expressionAggEvalProjection.target(aggregateResult)(currentBuffer)
+ // Generate results for all imperative aggregate functions.
var i = 0
- while (i < allNonAlgebraicAggregateFunctions.length) {
+ while (i < allImperativeAggregateFunctions.length) {
aggregateResult.update(
- allNonAlgebraicAggregateFunctionPositions(i),
- allNonAlgebraicAggregateFunctions(i).eval(currentBuffer))
+ allImperativeAggregateFunctionPositions(i),
+ allImperativeAggregateFunctions(i).eval(currentBuffer))
i += 1
}
resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult))
@@ -392,10 +398,10 @@ abstract class AggregationIterator(
/** Initializes buffer values for all aggregate functions. */
protected def initializeBuffer(buffer: MutableRow): Unit = {
- algebraicInitialProjection.target(buffer)(EmptyRow)
+ expressionAggInitialProjection.target(buffer)(EmptyRow)
var i = 0
- while (i < allNonAlgebraicAggregateFunctions.length) {
- allNonAlgebraicAggregateFunctions(i).initialize(buffer)
+ while (i < allImperativeAggregateFunctions.length) {
+ allImperativeAggregateFunctions(i).initialize(buffer)
i += 1
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index 73d50e07cf..a9e5d175bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -54,7 +54,7 @@ class SortBasedAggregationIterator(
outputsUnsafeRows) {
override protected def newBuffer: MutableRow = {
- val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+ val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
val bufferRowSize: Int = bufferSchema.length
val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index ba379d358d..3cd22af305 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -32,7 +32,6 @@ case class TungstenAggregate(
groupingExpressions: Seq[NamedExpression],
nonCompleteAggregateExpressions: Seq[AggregateExpression2],
completeAggregateExpressions: Seq[AggregateExpression2],
- initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryNode {
@@ -79,7 +78,6 @@ case class TungstenAggregate(
groupingExpressions,
nonCompleteAggregateExpressions,
completeAggregateExpressions,
- initialInputBufferOffset,
resultExpressions,
newMutableProjection,
child.output,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 26fdbc83ef..6a84c0af0b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -62,10 +62,6 @@ import org.apache.spark.sql.types.StructType
* [[PartialMerge]], or [[Final]].
* @param completeAggregateExpressions
* [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]].
- * @param initialInputBufferOffset
- * If this iterator is used to handle functions with mode [[PartialMerge]] or [[Final]].
- * The input rows have the format of `grouping keys + aggregation buffer`.
- * This offset indicates the starting position of aggregation buffer in a input row.
* @param resultExpressions
* expressions for generating output rows.
* @param newMutableProjection
@@ -77,7 +73,6 @@ class TungstenAggregationIterator(
groupingExpressions: Seq[NamedExpression],
nonCompleteAggregateExpressions: Seq[AggregateExpression2],
completeAggregateExpressions: Seq[AggregateExpression2],
- initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
@@ -133,17 +128,18 @@ class TungstenAggregationIterator(
completeAggregateExpressions.map(_.mode).distinct.headOption
}
- // All aggregate functions. TungstenAggregationIterator only handles AlgebraicAggregates.
- // If there is any functions that is not an AlgebraicAggregate, we throw an
+ // All aggregate functions. TungstenAggregationIterator only handles expression-based aggregates.
+ // If there are any functions that are ImperativeAggregates, we throw an
// IllegalStateException.
- private[this] val allAggregateFunctions: Array[AlgebraicAggregate] = {
- if (!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])) {
+ private[this] val allAggregateFunctions: Array[ExpressionAggregate] = {
+ if (!allAggregateExpressions.forall(
+ _.aggregateFunction.isInstanceOf[ExpressionAggregate])) {
throw new IllegalStateException(
- "Only AlgebraicAggregates should be passed in TungstenAggregationIterator.")
+ "Only ExpressionAggregateFunctions should be passed in TungstenAggregationIterator.")
}
allAggregateExpressions
- .map(_.aggregateFunction.asInstanceOf[AlgebraicAggregate])
+ .map(_.aggregateFunction.asInstanceOf[ExpressionAggregate])
.toArray
}
@@ -154,7 +150,7 @@ class TungstenAggregationIterator(
///////////////////////////////////////////////////////////////////////////
// The projection used to initialize buffer values.
- private[this] val algebraicInitialProjection: MutableProjection = {
+ private[this] val initialProjection: MutableProjection = {
val initExpressions = allAggregateFunctions.flatMap(_.initialValues)
newMutableProjection(initExpressions, Nil)()
}
@@ -164,14 +160,14 @@ class TungstenAggregationIterator(
// when we switch to sort-based aggregation, and when we create the re-used buffer for
// sort-based aggregation).
private def createNewAggregationBuffer(): UnsafeRow = {
- val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+ val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
val bufferRowSize: Int = bufferSchema.length
val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
val unsafeProjection =
UnsafeProjection.create(bufferSchema.map(_.dataType))
val buffer = unsafeProjection.apply(genericMutableBuffer)
- algebraicInitialProjection.target(buffer)(EmptyRow)
+ initialProjection.target(buffer)(EmptyRow)
buffer
}
@@ -179,84 +175,78 @@ class TungstenAggregationIterator(
private def generateProcessRow(
inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = {
- val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
+ val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes)
val joinedRow = new JoinedRow()
aggregationMode match {
// Partial-only
case (Some(Partial), None) =>
val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions)
- val algebraicUpdateProjection =
+ val updateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
(currentBuffer: UnsafeRow, row: InternalRow) => {
- algebraicUpdateProjection.target(currentBuffer)
- algebraicUpdateProjection(joinedRow(currentBuffer, row))
+ updateProjection.target(currentBuffer)
+ updateProjection(joinedRow(currentBuffer, row))
}
// PartialMerge-only or Final-only
case (Some(PartialMerge), None) | (Some(Final), None) =>
val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions)
- // This projection is used to merge buffer values for all AlgebraicAggregates.
- val algebraicMergeProjection =
- newMutableProjection(
- mergeExpressions,
- aggregationBufferAttributes ++ inputAttributes)()
+ val mergeProjection =
+ newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)()
(currentBuffer: UnsafeRow, row: InternalRow) => {
- // Process all algebraic aggregate functions.
- algebraicMergeProjection.target(currentBuffer)
- algebraicMergeProjection(joinedRow(currentBuffer, row))
+ mergeProjection.target(currentBuffer)
+ mergeProjection(joinedRow(currentBuffer, row))
}
// Final-Complete
case (Some(Final), Some(Complete)) =>
- val nonCompleteAggregateFunctions: Array[AlgebraicAggregate] =
+ val nonCompleteAggregateFunctions: Array[ExpressionAggregate] =
allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
- val completeAggregateFunctions: Array[AlgebraicAggregate] =
+ val completeAggregateFunctions: Array[ExpressionAggregate] =
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
val completeOffsetExpressions =
- Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+ Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
val mergeExpressions =
nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions
- val finalAlgebraicMergeProjection =
- newMutableProjection(
- mergeExpressions,
- aggregationBufferAttributes ++ inputAttributes)()
+ val finalMergeProjection =
+ newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)()
// We do not touch buffer values of aggregate functions with the Final mode.
val finalOffsetExpressions =
- Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+ Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp)
val updateExpressions =
finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions)
- val completeAlgebraicUpdateProjection =
+ val completeUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
(currentBuffer: UnsafeRow, row: InternalRow) => {
val input = joinedRow(currentBuffer, row)
// For all aggregate functions with mode Complete, update the given currentBuffer.
- completeAlgebraicUpdateProjection.target(currentBuffer)(input)
+ completeUpdateProjection.target(currentBuffer)(input)
// For all aggregate functions with mode Final, merge buffer values in row to
// currentBuffer.
- finalAlgebraicMergeProjection.target(currentBuffer)(input)
+ finalMergeProjection.target(currentBuffer)(input)
}
// Complete-only
case (None, Some(Complete)) =>
- val completeAggregateFunctions: Array[AlgebraicAggregate] =
+ val completeAggregateFunctions: Array[ExpressionAggregate] =
allAggregateFunctions.takeRight(completeAggregateExpressions.length)
val updateExpressions =
completeAggregateFunctions.flatMap(_.updateExpressions)
- val completeAlgebraicUpdateProjection =
+ val completeUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
(currentBuffer: UnsafeRow, row: InternalRow) => {
- completeAlgebraicUpdateProjection.target(currentBuffer)
+ completeUpdateProjection.target(currentBuffer)
// For all aggregate functions with mode Complete, update the given currentBuffer.
- completeAlgebraicUpdateProjection(joinedRow(currentBuffer, row))
+ completeUpdateProjection(joinedRow(currentBuffer, row))
}
// Grouping only.
@@ -272,7 +262,7 @@ class TungstenAggregationIterator(
private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = {
val groupingAttributes = groupingExpressions.map(_.toAttribute)
- val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
+ val bufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes)
aggregationMode match {
// Partial-only or PartialMerge-only: every output row is basically the values of
@@ -339,7 +329,7 @@ class TungstenAggregationIterator(
// all groups and their corresponding aggregation buffers for hash-based aggregation.
private[this] val hashMap = new UnsafeFixedWidthAggregationMap(
initialAggregationBuffer,
- StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
+ StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)),
StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
TaskContext.get.taskMemoryManager(),
SparkEnv.get.shuffleMemoryManager,
@@ -480,7 +470,7 @@ class TungstenAggregationIterator(
// The originalInputAttributes are using cloneBufferAttributes. So, we need to use
// allAggregateFunctions.flatMap(_.cloneBufferAttributes).
val bufferExtractor = newMutableProjection(
- allAggregateFunctions.flatMap(_.cloneBufferAttributes),
+ allAggregateFunctions.flatMap(_.inputAggBufferAttributes),
originalInputAttributes)()
bufferExtractor.target(buffer)
@@ -510,7 +500,7 @@ class TungstenAggregationIterator(
// Basically the value of the KVIterator returned by externalSorter
// will just aggregation buffer. At here, we use cloneBufferAttributes.
val newInputAttributes: Seq[Attribute] =
- allAggregateFunctions.flatMap(_.cloneBufferAttributes)
+ allAggregateFunctions.flatMap(_.inputAggBufferAttributes)
// Set up new processRow and generateOutput.
processRow = generateProcessRow(newInputAttributes)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 1114fe6552..fd02be1225 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression}
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction2}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
@@ -318,13 +318,11 @@ private[sql] class InputAggregationBuffer private[sql] (
/**
* The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the
* internal aggregation code path.
- * @param children
- * @param udaf
*/
private[sql] case class ScalaUDAF(
children: Seq[Expression],
udaf: UserDefinedAggregateFunction)
- extends AggregateFunction2 with Logging {
+ extends ImperativeAggregate with Logging {
require(
children.length == udaf.inputSchema.length,
@@ -339,11 +337,9 @@ private[sql] case class ScalaUDAF(
override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType)
- override val bufferSchema: StructType = udaf.bufferSchema
+ override val aggBufferSchema: StructType = udaf.bufferSchema
- override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes
-
- override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+ override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
private[this] lazy val childrenSchema: StructType = {
val inputFields = children.zipWithIndex.map {
@@ -370,13 +366,13 @@ private[sql] case class ScalaUDAF(
CatalystTypeConverters.createToScalaConverter(childrenSchema)
private[this] lazy val bufferValuesToCatalystConverters: Array[Any => Any] = {
- bufferSchema.fields.map { field =>
+ aggBufferSchema.fields.map { field =>
CatalystTypeConverters.createToCatalystConverter(field.dataType)
}
}
private[this] lazy val bufferValuesToScalaConverters: Array[Any => Any] = {
- bufferSchema.fields.map { field =>
+ aggBufferSchema.fields.map { field =>
CatalystTypeConverters.createToScalaConverter(field.dataType)
}
}
@@ -398,15 +394,15 @@ private[sql] case class ScalaUDAF(
* Sets the inputBufferOffset to newInputBufferOffset and then create a new instance of
* `inputAggregateBuffer` based on this new inputBufferOffset.
*/
- override def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = {
- super.withNewInputBufferOffset(newInputBufferOffset)
+ override def withNewInputAggBufferOffset(newInputBufferOffset: Int): Unit = {
+ super.withNewInputAggBufferOffset(newInputBufferOffset)
// inputBufferOffset has been updated.
inputAggregateBuffer =
new InputAggregationBuffer(
- bufferSchema,
+ aggBufferSchema,
bufferValuesToCatalystConverters,
bufferValuesToScalaConverters,
- inputBufferOffset,
+ inputAggBufferOffset,
null)
}
@@ -414,22 +410,22 @@ private[sql] case class ScalaUDAF(
* Sets the mutableBufferOffset to newMutableBufferOffset and then create a new instance of
* `mutableAggregateBuffer` and `evalAggregateBuffer` based on this new mutableBufferOffset.
*/
- override def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = {
- super.withNewMutableBufferOffset(newMutableBufferOffset)
+ override def withNewMutableAggBufferOffset(newMutableBufferOffset: Int): Unit = {
+ super.withNewMutableAggBufferOffset(newMutableBufferOffset)
// mutableBufferOffset has been updated.
mutableAggregateBuffer =
new MutableAggregationBufferImpl(
- bufferSchema,
+ aggBufferSchema,
bufferValuesToCatalystConverters,
bufferValuesToScalaConverters,
- mutableBufferOffset,
+ mutableAggBufferOffset,
null)
evalAggregateBuffer =
new InputAggregationBuffer(
- bufferSchema,
+ aggBufferSchema,
bufferValuesToCatalystConverters,
bufferValuesToScalaConverters,
- mutableBufferOffset,
+ mutableAggBufferOffset,
null)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 4f5e86cceb..e1d7e1bf02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -97,10 +97,10 @@ object Utils {
// Check if we can use TungstenAggregate.
val usesTungstenAggregate =
child.sqlContext.conf.unsafeEnabled &&
- aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) &&
+ aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[ExpressionAggregate]) &&
supportsTungstenAggregate(
groupingExpressions,
- aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
// 1. Create an Aggregate Operator for partial aggregations.
@@ -117,10 +117,10 @@ object Utils {
val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
val partialAggregateAttributes =
- partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
+ partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
val partialResultExpressions =
namedGroupingAttributes ++
- partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
val partialAggregate = if (usesTungstenAggregate) {
TungstenAggregate(
@@ -128,7 +128,6 @@ object Utils {
groupingExpressions = namedGroupingExpressions.map(_._2),
nonCompleteAggregateExpressions = partialAggregateExpressions,
completeAggregateExpressions = Nil,
- initialInputBufferOffset = 0,
resultExpressions = partialResultExpressions,
child = child)
} else {
@@ -158,7 +157,7 @@ object Utils {
// aggregateFunctionMap contains unique aggregate functions.
val aggregateFunction =
aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1
- aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression
+ aggregateFunction.asInstanceOf[ExpressionAggregate].evaluateExpression
case expression =>
// We do not rely on the equality check at here since attributes may
// different cosmetically. Instead, we use semanticEquals.
@@ -173,7 +172,6 @@ object Utils {
groupingExpressions = namedGroupingAttributes,
nonCompleteAggregateExpressions = finalAggregateExpressions,
completeAggregateExpressions = Nil,
- initialInputBufferOffset = namedGroupingAttributes.length,
resultExpressions = rewrittenResultExpressions,
child = partialAggregate)
} else {
@@ -216,10 +214,11 @@ object Utils {
val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct
val usesTungstenAggregate =
child.sqlContext.conf.unsafeEnabled &&
- aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) &&
+ aggregateExpressions.forall(
+ _.aggregateFunction.isInstanceOf[ExpressionAggregate]) &&
supportsTungstenAggregate(
groupingExpressions,
- aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
// 1. Create an Aggregate Operator for partial aggregations.
// The grouping expressions are original groupingExpressions and
@@ -252,20 +251,19 @@ object Utils {
val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
val partialAggregateAttributes =
- partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
+ partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
val partialAggregateGroupingExpressions =
(namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2)
val partialAggregateResult =
namedGroupingAttributes ++
distinctColumnAttributes ++
- partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
val partialAggregate = if (usesTungstenAggregate) {
TungstenAggregate(
requiredChildDistributionExpressions = None: Option[Seq[Expression]],
groupingExpressions = partialAggregateGroupingExpressions,
nonCompleteAggregateExpressions = partialAggregateExpressions,
completeAggregateExpressions = Nil,
- initialInputBufferOffset = 0,
resultExpressions = partialAggregateResult,
child = child)
} else {
@@ -284,18 +282,17 @@ object Utils {
// 2. Create an Aggregate Operator for partial merge aggregations.
val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
val partialMergeAggregateAttributes =
- partialMergeAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
+ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
val partialMergeAggregateResult =
namedGroupingAttributes ++
distinctColumnAttributes ++
- partialMergeAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
val partialMergeAggregate = if (usesTungstenAggregate) {
TungstenAggregate(
requiredChildDistributionExpressions = Some(namedGroupingAttributes),
groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
completeAggregateExpressions = Nil,
- initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = partialMergeAggregateResult,
child = partialAggregate)
} else {
@@ -362,7 +359,7 @@ object Utils {
// aggregate functions that have not been rewritten.
aggregateFunctionMap(function, isDistinct)._1
}
- aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression
+ aggregateFunction.asInstanceOf[ExpressionAggregate].evaluateExpression
case expression =>
// We do not rely on the equality check at here since attributes may
// different cosmetically. Instead, we use semanticEquals.
@@ -377,7 +374,6 @@ object Utils {
groupingExpressions = namedGroupingAttributes,
nonCompleteAggregateExpressions = finalAggregateExpressions,
completeAggregateExpressions = completeAggregateExpressions,
- initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = rewrittenResultExpressions,
child = partialMergeAggregate)
} else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index afda0d29f6..7ca677a6c7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -38,7 +38,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte
() => new InterpretedMutableProjection(expr, schema)
}
val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
- iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0,
+ iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty,
Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
val numPages = iter.getHashMap.getNumDataPages
assert(numPages === 1)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index a85d4db88d..18bbdb9908 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -554,7 +554,7 @@ private[hive] case class HiveUDAFFunction(
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression],
isUDAFBridgeRequired: Boolean = false)
- extends AggregateFunction2 with HiveInspectors {
+ extends ImperativeAggregate with HiveInspectors {
def this() = this(null, null)
@@ -598,7 +598,7 @@ private[hive] case class HiveUDAFFunction(
// Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation
// buffer for it.
- override def bufferSchema: StructType = StructType(Nil)
+ override def aggBufferSchema: StructType = StructType(Nil)
override def update(_buffer: MutableRow, input: InternalRow): Unit = {
val inputs = inputProjection(input)
@@ -610,13 +610,11 @@ private[hive] case class HiveUDAFFunction(
"Hive UDAF doesn't support partial aggregate")
}
- override def cloneBufferAttributes: Seq[Attribute] = Nil
-
override def initialize(_buffer: MutableRow): Unit = {
buffer = function.getNewAggregationBuffer
}
- override def bufferAttributes: Seq[AttributeReference] = Nil
+ override def aggBufferAttributes: Seq[AttributeReference] = Nil
// We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
// catalyst type checking framework.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 24b1846923..c9e1bb1995 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -343,7 +343,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(null, null, 110.0, null, null, 10.0) :: Nil)
}
- test("non-AlgebraicAggregate aggreguate function") {
+ test("interpreted aggregate function") {
checkAnswer(
sqlContext.sql(
"""
@@ -368,7 +368,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(null) :: Nil)
}
- test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") {
+ test("interpreted and expression-based aggregation functions") {
checkAnswer(
sqlContext.sql(
"""