path: root/sql/catalyst
author    Josh Rosen <joshrosen@databricks.com>  2015-10-07 13:19:49 -0700
committer Yin Huai <yhuai@databricks.com>        2015-10-07 13:19:49 -0700
commit    a9ecd06149df4ccafd3927c35f63b9f03f170ae5
tree      99d41c5b5f8cd786e2c161c339373a740b4048eb /sql/catalyst
parent    5be5d247440d6346d667c4b3d817666126f62906
[SPARK-10941] [SQL] Refactor AggregateFunction2 and AlgebraicAggregate interfaces to improve code clarity
This patch refactors several of the Aggregate2 interfaces in order to improve code clarity.

The biggest change is a refactoring of the `AggregateFunction2` class hierarchy. In the old code, we had a class named `AlgebraicAggregate` that inherited from `AggregateFunction2`, added a new set of methods, and then banned the use of the inherited methods. I found this fairly confusing: if you look carefully at the existing code, you'll see that subclasses of `AggregateFunction2` fall into two disjoint categories: imperative aggregation functions, which directly extended `AggregateFunction2`, and declarative, expression-based aggregate functions, which extended `AlgebraicAggregate`. In order to make this distinction explicit, this patch refactors things so that `AggregateFunction2` is a sealed abstract class with two subclasses, `ImperativeAggregate` and `ExpressionAggregate`. The superclass, `AggregateFunction2`, now contains only the methods and fields that are common to both subclasses.

After making this change, I updated the various AggregationIterator classes to comply with the new naming scheme. I also performed several small renamings in the aggregate interfaces themselves in order to improve clarity, and rewrote or expanded a number of comments.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #8973 from JoshRosen/tungsten-agg-comments.
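For orientation, here is a condensed sketch of the resulting hierarchy, abridged from the interfaces.scala changes in the diff below (method bodies and several members omitted):

sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes {
  // Schema and attributes of this function's slice of the shared aggregation buffer.
  def aggBufferSchema: StructType
  def aggBufferAttributes: Seq[AttributeReference]
  def inputAggBufferAttributes: Seq[AttributeReference]
  def supportsPartial: Boolean = true
}

// Row-based functions: initialize()/update()/merge() mutate a shared MutableRow buffer.
abstract class ImperativeAggregate extends AggregateFunction2 {
  def initialize(mutableAggBuffer: MutableRow): Unit
  def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit
  def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit
}

// Expression-based functions: the whole aggregate is described with Catalyst expressions.
abstract class ExpressionAggregate extends AggregateFunction2 with Serializable with Unevaluable {
  val initialValues: Seq[Expression]
  val updateExpressions: Seq[Expression]
  val mergeExpressions: Seq[Expression]
  val evaluateExpression: Expression
}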
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala              |  70
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala             | 235
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala |   2
3 files changed, 188 insertions(+), 119 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index e7f8104d43..4ad2607a85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-case class Average(child: Expression) extends AlgebraicAggregate {
+case class Average(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -57,7 +57,7 @@ case class Average(child: Expression) extends AlgebraicAggregate {
private val currentSum = AttributeReference("currentSum", sumDataType)()
private val currentCount = AttributeReference("currentCount", LongType)()
- override val bufferAttributes = currentSum :: currentCount :: Nil
+ override val aggBufferAttributes = currentSum :: currentCount :: Nil
override val initialValues = Seq(
/* currentSum = */ Cast(Literal(0), sumDataType),
@@ -88,7 +88,7 @@ case class Average(child: Expression) extends AlgebraicAggregate {
}
}
-case class Count(child: Expression) extends AlgebraicAggregate {
+case class Count(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = false
@@ -101,7 +101,7 @@ case class Count(child: Expression) extends AlgebraicAggregate {
private val currentCount = AttributeReference("currentCount", LongType)()
- override val bufferAttributes = currentCount :: Nil
+ override val aggBufferAttributes = currentCount :: Nil
override val initialValues = Seq(
/* currentCount = */ Literal(0L)
@@ -118,7 +118,7 @@ case class Count(child: Expression) extends AlgebraicAggregate {
override val evaluateExpression = Cast(currentCount, LongType)
}
-case class First(child: Expression) extends AlgebraicAggregate {
+case class First(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -135,7 +135,7 @@ case class First(child: Expression) extends AlgebraicAggregate {
private val first = AttributeReference("first", child.dataType)()
- override val bufferAttributes = first :: Nil
+ override val aggBufferAttributes = first :: Nil
override val initialValues = Seq(
/* first = */ Literal.create(null, child.dataType)
@@ -152,7 +152,7 @@ case class First(child: Expression) extends AlgebraicAggregate {
override val evaluateExpression = first
}
-case class Last(child: Expression) extends AlgebraicAggregate {
+case class Last(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -169,7 +169,7 @@ case class Last(child: Expression) extends AlgebraicAggregate {
private val last = AttributeReference("last", child.dataType)()
- override val bufferAttributes = last :: Nil
+ override val aggBufferAttributes = last :: Nil
override val initialValues = Seq(
/* last = */ Literal.create(null, child.dataType)
@@ -186,7 +186,7 @@ case class Last(child: Expression) extends AlgebraicAggregate {
override val evaluateExpression = last
}
-case class Max(child: Expression) extends AlgebraicAggregate {
+case class Max(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -200,7 +200,7 @@ case class Max(child: Expression) extends AlgebraicAggregate {
private val max = AttributeReference("max", child.dataType)()
- override val bufferAttributes = max :: Nil
+ override val aggBufferAttributes = max :: Nil
override val initialValues = Seq(
/* max = */ Literal.create(null, child.dataType)
@@ -220,7 +220,7 @@ case class Max(child: Expression) extends AlgebraicAggregate {
override val evaluateExpression = max
}
-case class Min(child: Expression) extends AlgebraicAggregate {
+case class Min(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -234,7 +234,7 @@ case class Min(child: Expression) extends AlgebraicAggregate {
private val min = AttributeReference("min", child.dataType)()
- override val bufferAttributes = min :: Nil
+ override val aggBufferAttributes = min :: Nil
override val initialValues = Seq(
/* min = */ Literal.create(null, child.dataType)
@@ -277,7 +277,7 @@ case class StddevSamp(child: Expression) extends StddevAgg(child) {
// Compute standard deviation based on online algorithm specified here:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-abstract class StddevAgg(child: Expression) extends AlgebraicAggregate {
+abstract class StddevAgg(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -304,7 +304,7 @@ abstract class StddevAgg(child: Expression) extends AlgebraicAggregate {
private val currentAvg = AttributeReference("currentAvg", resultType)()
private val currentMk = AttributeReference("currentMk", resultType)()
- override val bufferAttributes = preCount :: currentCount :: preAvg ::
+ override val aggBufferAttributes = preCount :: currentCount :: preAvg ::
currentAvg :: currentMk :: Nil
override val initialValues = Seq(
@@ -397,7 +397,7 @@ abstract class StddevAgg(child: Expression) extends AlgebraicAggregate {
}
}
-case class Sum(child: Expression) extends AlgebraicAggregate {
+case class Sum(child: Expression) extends ExpressionAggregate {
override def children: Seq[Expression] = child :: Nil
@@ -429,7 +429,7 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
private val zero = Cast(Literal(0), sumDataType)
- override val bufferAttributes = currentSum :: Nil
+ override val aggBufferAttributes = currentSum :: Nil
override val initialValues = Seq(
/* currentSum = */ Literal.create(null, sumDataType)
@@ -473,7 +473,7 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
*/
// scalastyle:on
case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
- extends AggregateFunction2 {
+ extends ImperativeAggregate {
import HyperLogLogPlusPlus._
/**
@@ -531,28 +531,26 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
*/
private[this] val numWords = m / REGISTERS_PER_WORD + 1
- def children: Seq[Expression] = Seq(child)
+ override def children: Seq[Expression] = Seq(child)
- def nullable: Boolean = false
-
- def dataType: DataType = LongType
+ override def nullable: Boolean = false
- def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+ override def dataType: DataType = LongType
- def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- def cloneBufferAttributes: Seq[Attribute] = bufferAttributes.map(_.newInstance())
+ override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
/** Allocate enough words to store all registers. */
- val bufferAttributes: Seq[AttributeReference] = Seq.tabulate(numWords) { i =>
+ override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(numWords) { i =>
AttributeReference(s"MS[$i]", LongType)()
}
/** Fill all words with zeros. */
- def initialize(buffer: MutableRow): Unit = {
+ override def initialize(buffer: MutableRow): Unit = {
var word = 0
while (word < numWords) {
- buffer.setLong(mutableBufferOffset + word, 0)
+ buffer.setLong(mutableAggBufferOffset + word, 0)
word += 1
}
}
@@ -562,7 +560,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
*
* Variable names in the HLL++ paper match variable names in the code.
*/
- def update(buffer: MutableRow, input: InternalRow): Unit = {
+ override def update(buffer: MutableRow, input: InternalRow): Unit = {
val v = child.eval(input)
if (v != null) {
// Create the hashed value 'x'.
@@ -576,7 +574,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
// Get the word containing the register we are interested in.
val wordOffset = idx / REGISTERS_PER_WORD
- val word = buffer.getLong(mutableBufferOffset + wordOffset)
+ val word = buffer.getLong(mutableAggBufferOffset + wordOffset)
// Extract the M[J] register value from the word.
val shift = REGISTER_SIZE * (idx - (wordOffset * REGISTERS_PER_WORD))
@@ -585,7 +583,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
// Assign the maximum number of leading zeros to the register.
if (pw > Midx) {
- buffer.setLong(mutableBufferOffset + wordOffset, (word & ~mask) | (pw << shift))
+ buffer.setLong(mutableAggBufferOffset + wordOffset, (word & ~mask) | (pw << shift))
}
}
}
@@ -594,12 +592,12 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
* Merge the HLL buffers by iterating through the registers in both buffers and select the
* maximum number of leading zeros for each register.
*/
- def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
var idx = 0
var wordOffset = 0
while (wordOffset < numWords) {
- val word1 = buffer1.getLong(mutableBufferOffset + wordOffset)
- val word2 = buffer2.getLong(inputBufferOffset + wordOffset)
+ val word1 = buffer1.getLong(mutableAggBufferOffset + wordOffset)
+ val word2 = buffer2.getLong(inputAggBufferOffset + wordOffset)
var word = 0L
var i = 0
var mask = REGISTER_WORD_MASK
@@ -609,7 +607,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
i += 1
idx += 1
}
- buffer1.setLong(mutableBufferOffset + wordOffset, word)
+ buffer1.setLong(mutableAggBufferOffset + wordOffset, word)
wordOffset += 1
}
}
@@ -664,14 +662,14 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
*
* Variable names in the HLL++ paper match variable names in the code.
*/
- def eval(buffer: InternalRow): Any = {
+ override def eval(buffer: InternalRow): Any = {
// Compute the inverse of indicator value 'z' and count the number of zeros 'V'.
var zInverse = 0.0d
var V = 0.0d
var idx = 0
var wordOffset = 0
while (wordOffset < numWords) {
- val word = buffer.getLong(mutableBufferOffset + wordOffset)
+ val word = buffer.getLong(mutableAggBufferOffset + wordOffset)
var i = 0
var shift = 0
while (idx < m && i < REGISTERS_PER_WORD) {
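The offset renamings above follow the access pattern documented in the interfaces.scala changes below: every buffer read or write adds the function's own offset into the shared buffer. As a reading aid, a minimal hypothetical ImperativeAggregate (SimpleLongSum is not part of this patch and assumes the same imports as functions.scala) might look like:

case class SimpleLongSum(child: Expression) extends ImperativeAggregate {
  override def children: Seq[Expression] = child :: Nil
  override def nullable: Boolean = false
  override def dataType: DataType = LongType
  override def inputTypes: Seq[AbstractDataType] = Seq(LongType)

  // A single long field holds the running sum.
  override val aggBufferAttributes: Seq[AttributeReference] =
    AttributeReference("sum", LongType)() :: Nil
  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

  // Field 0 of this function's buffer lives at position mutableAggBufferOffset + 0.
  override def initialize(buffer: MutableRow): Unit =
    buffer.setLong(mutableAggBufferOffset, 0L)

  override def update(buffer: MutableRow, input: InternalRow): Unit = {
    val v = child.eval(input)
    if (v != null) {
      buffer.setLong(mutableAggBufferOffset,
        buffer.getLong(mutableAggBufferOffset) + v.asInstanceOf[Long])
    }
  }

  // The immutable input buffer is addressed with inputAggBufferOffset instead.
  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit =
    buffer1.setLong(mutableAggBufferOffset,
      buffer1.getLong(mutableAggBufferOffset) + buffer2.getLong(inputAggBufferOffset))

  override def eval(buffer: InternalRow): Any = buffer.getLong(mutableAggBufferOffset)
}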
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index d8699533cd..74e15ec90b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -69,9 +69,6 @@ private[sql] case object NoOp extends Expression with Unevaluable {
/**
* A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
* (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
- * @param aggregateFunction
- * @param mode
- * @param isDistinct
*/
private[sql] case class AggregateExpression2(
aggregateFunction: AggregateFunction2,
@@ -86,7 +83,7 @@ private[sql] case class AggregateExpression2(
override def references: AttributeSet = {
val childReferences = mode match {
case Partial | Complete => aggregateFunction.references.toSeq
- case PartialMerge | Final => aggregateFunction.bufferAttributes
+ case PartialMerge | Final => aggregateFunction.aggBufferAttributes
}
AttributeSet(childReferences)
@@ -95,98 +92,192 @@ private[sql] case class AggregateExpression2(
override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)"
}
-abstract class AggregateFunction2
- extends Expression with ImplicitCastInputTypes {
+/**
+ * AggregateFunction2 is the superclass of two aggregation function interfaces:
+ *
+ * - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of
+ * initialize(), update(), and merge() functions that operate on Row-based aggregation buffers.
+ * - [[ExpressionAggregate]] is for aggregation functions that are specified using
+ * Catalyst expressions.
+ *
+ * In both interfaces, aggregates must define the schema ([[aggBufferSchema]]) and attributes
+ * ([[aggBufferAttributes]]) of an aggregation buffer which is used to hold partial aggregate
+ * results. At runtime, multiple aggregate functions are evaluated by the same operator using a
+ * combined aggregation buffer which concatenates the aggregation buffers of the individual
+ * aggregate functions.
+ *
+ * Code which accepts [[AggregateFunction2]] instances should be prepared to handle both types of
+ * aggregate functions.
+ */
+sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes {
/** An aggregate function is not foldable. */
final override def foldable: Boolean = false
+ /** The schema of the aggregation buffer. */
+ def aggBufferSchema: StructType
+
+ /** Attributes of fields in aggBufferSchema. */
+ def aggBufferAttributes: Seq[AttributeReference]
+
/**
- * The offset of this function's start buffer value in the
- * underlying shared mutable aggregation buffer.
- * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share
- * the same aggregation buffer. In this shared buffer, the position of the first
- * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)`
- * will be 2.
+ * Attributes of fields in input aggregation buffers (immutable aggregation buffers that are
+ * merged with mutable aggregation buffers in the merge() function or merge expressions).
+ * These attributes are created automatically by cloning the [[aggBufferAttributes]].
*/
- protected var mutableBufferOffset: Int = 0
-
- def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = {
- mutableBufferOffset = newMutableBufferOffset
- }
+ def inputAggBufferAttributes: Seq[AttributeReference]
/**
- * The offset of this function's start buffer value in the
- * underlying shared input aggregation buffer. An input aggregation buffer is used
- * when we merge two aggregation buffers and it is basically the immutable one
- * (we merge an input aggregation buffer and a mutable aggregation buffer and
- * then store the new buffer values to the mutable aggregation buffer).
- * Usually, an input aggregation buffer also contain extra elements like grouping
- * keys at the beginning. So, mutableBufferOffset and inputBufferOffset are often
- * different.
- * For example, we have a grouping expression `key``, and two aggregate functions
- * `avg(x)` and `avg(y)`. In this shared input aggregation buffer, the position of the first
- * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)`
- * will be 3 (position 0 is used for the value of key`).
+ * Indicates if this function supports partial aggregation.
+ * Currently Hive UDAF is the only one that doesn't support partial aggregation.
*/
- protected var inputBufferOffset: Int = 0
+ def supportsPartial: Boolean = true
- def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = {
- inputBufferOffset = newInputBufferOffset
- }
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+}
- /** The schema of the aggregation buffer. */
- def bufferSchema: StructType
+/**
+ * API for aggregation functions that are expressed in terms of imperative initialize(), update(),
+ * and merge() functions which operate on Row-based aggregation buffers.
+ *
+ * Within these functions, code should access fields of the mutable aggregation buffer by adding the
+ * bufferSchema-relative field number to `mutableAggBufferOffset` then using this new field number
+ * to access the buffer Row. This is necessary because this aggregation function's buffer is
+ * embedded inside of a larger shared aggregation buffer when an aggregation operator evaluates
+ * multiple aggregate functions at the same time.
+ *
+ * We need to perform similar field number arithmetic when merging multiple intermediate
+ * aggregate buffers together in `merge()` (in this case, use `inputAggBufferOffset` when accessing
+ * the input buffer).
+ */
+abstract class ImperativeAggregate extends AggregateFunction2 {
- /** Attributes of fields in bufferSchema. */
- def bufferAttributes: Seq[AttributeReference]
+ /**
+ * The offset of this function's first buffer value in the underlying shared mutable aggregation
+ * buffer.
+ *
+ * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share the same
+ * aggregation buffer. In this shared buffer, the position of the first buffer value of `avg(x)`
+ * will be 0 and the position of the first buffer value of `avg(y)` will be 2:
+ *
+ * avg(x) mutableAggBufferOffset = 0
+ * |
+ * v
+ * +--------+--------+--------+--------+
+ * | sum1 | count1 | sum2 | count2 |
+ * +--------+--------+--------+--------+
+ * ^
+ * |
+ * avg(y) mutableAggBufferOffset = 2
+ *
+ */
+ protected var mutableAggBufferOffset: Int = 0
- /** Clones bufferAttributes. */
- def cloneBufferAttributes: Seq[Attribute]
+ def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Unit = {
+ mutableAggBufferOffset = newMutableAggBufferOffset
+ }
/**
- * Initializes its aggregation buffer located in `buffer`.
- * It will use bufferOffset to find the starting point of
- * its buffer in the given `buffer` shared with other functions.
+ * The offset of this function's start buffer value in the underlying shared input aggregation
+ * buffer. An input aggregation buffer is used when we merge two aggregation buffers together in
+ * the `update()` function and is immutable (we merge an input aggregation buffer and a mutable
+ * aggregation buffer and then store the new buffer values to the mutable aggregation buffer).
+ *
+ * An input aggregation buffer may contain extra fields, such as grouping keys, at its start, so
+ * mutableAggBufferOffset and inputAggBufferOffset are often different.
+ *
+ * For example, say we have a grouping expression, `key`, and two aggregate functions,
+ * `avg(x)` and `avg(y)`. In the shared input aggregation buffer, the position of the first
+ * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)`
+ * will be 3 (position 0 is used for the value of `key`):
+ *
+ * avg(x) inputAggBufferOffset = 1
+ * |
+ * v
+ * +--------+--------+--------+--------+--------+
+ * | key | sum1 | count1 | sum2 | count2 |
+ * +--------+--------+--------+--------+--------+
+ * ^
+ * |
+ * avg(y) inputAggBufferOffset = 3
+ *
*/
- def initialize(buffer: MutableRow): Unit
+ protected var inputAggBufferOffset: Int = 0
+
+ def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Unit = {
+ inputAggBufferOffset = newInputAggBufferOffset
+ }
/**
- * Updates its aggregation buffer located in `buffer` based on the given `input`.
- * It will use bufferOffset to find the starting point of its buffer in the given `buffer`
- * shared with other functions.
+ * Initializes the mutable aggregation buffer located in `mutableAggBuffer`.
+ *
+ * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
*/
- def update(buffer: MutableRow, input: InternalRow): Unit
+ def initialize(mutableAggBuffer: MutableRow): Unit
/**
- * Updates its aggregation buffer located in `buffer1` by combining intermediate results
- * in the current buffer and intermediate results from another buffer `buffer2`.
- * It will use bufferOffset to find the starting point of its buffer in the given `buffer1`
- * and `buffer2`.
+ * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`.
+ *
+ * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
*/
- def merge(buffer1: MutableRow, buffer2: InternalRow): Unit
-
- override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
- throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+ def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit
/**
- * Indicates if this function supports partial aggregation.
- * Currently Hive UDAF is the only one that doesn't support partial aggregation.
+ * Combines new intermediate results from the `inputAggBuffer` with the existing intermediate
+ * results in the `mutableAggBuffer`.
+ *
+ * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
+ * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`.
*/
- def supportsPartial: Boolean = true
+ def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit
+
+ final lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
}
/**
- * A helper class for aggregate functions that can be implemented in terms of catalyst expressions.
+ * API for aggregation functions that are expressed in terms of Catalyst expressions.
+ *
+ * When implementing a new expression-based aggregate function, start by implementing
+ * `bufferAttributes`, defining attributes for the fields of the mutable aggregation buffer. You
+ * can then use these attributes when defining `updateExpressions`, `mergeExpressions`, and
+ * `evaluateExpressions`.
*/
-abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable with Unevaluable {
+abstract class ExpressionAggregate
+ extends AggregateFunction2
+ with Serializable
+ with Unevaluable {
+ /**
+ * Expressions for initializing empty aggregation buffers.
+ */
val initialValues: Seq[Expression]
+
+ /**
+ * Expressions for updating the mutable aggregation buffer based on an input row.
+ */
val updateExpressions: Seq[Expression]
+
+ /**
+ * A sequence of expressions for merging two aggregation buffers together. When defining these
+ * expressions, you can use the syntax `attributeName.left` and `attributeName.right` to refer
+ * to the attributes corresponding to each of the buffers being merged (this magic is enabled
+ * by the [[RichAttribute]] implicit class).
+ */
val mergeExpressions: Seq[Expression]
+
+ /**
+ * An expression which returns the final value for this aggregate function. Its data type should
+ * match this expression's [[dataType]].
+ */
val evaluateExpression: Expression
- override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+ /** An expression-based aggregate's bufferSchema is derived from bufferAttributes. */
+ final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+ final lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
/**
* A helper class for representing an attribute used in merging two
@@ -194,33 +285,13 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w
* we merge buffer values and then update bufferLeft. A [[RichAttribute]]
* of an [[AttributeReference]] `a` has two functions `left` and `right`,
* which represent `a` in `bufferLeft` and `bufferRight`, respectively.
- * @param a
*/
implicit class RichAttribute(a: AttributeReference) {
/** Represents this attribute at the mutable buffer side. */
def left: AttributeReference = a
/** Represents this attribute at the input buffer side (the data value is read-only). */
- def right: AttributeReference = cloneBufferAttributes(bufferAttributes.indexOf(a))
- }
-
- /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */
- override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)
-
- override def initialize(buffer: MutableRow): Unit = {
- throw new UnsupportedOperationException(
- "AlgebraicAggregate's initialize should not be called directly")
- }
-
- override final def update(buffer: MutableRow, input: InternalRow): Unit = {
- throw new UnsupportedOperationException(
- "AlgebraicAggregate's update should not be called directly")
+ def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
}
-
- override final def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- throw new UnsupportedOperationException(
- "AlgebraicAggregate's merge should not be called directly")
- }
-
}
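On the expression-based side, a minimal sketch in the style of the Count implementation from functions.scala above (abridged, with comments added here; the `+` on expressions comes from the catalyst dsl import shown at the top of that file's diff) shows how the buffer attributes feed the update, merge, and evaluate expressions:

case class Count(child: Expression) extends ExpressionAggregate {
  override def children: Seq[Expression] = child :: Nil
  override def nullable: Boolean = false
  override def dataType: DataType = LongType
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

  // The mutable aggregation buffer holds a single long counter.
  private val currentCount = AttributeReference("currentCount", LongType)()
  override val aggBufferAttributes = currentCount :: Nil

  override val initialValues = Seq(/* currentCount = */ Literal(0L))

  // Increment the counter only when the child expression is non-null.
  override val updateExpressions = Seq(
    /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L))

  // `.left` is the mutable buffer's copy of the attribute, `.right` is the input
  // buffer's copy (via the RichAttribute helper shown above).
  override val mergeExpressions = Seq(
    /* currentCount = */ currentCount.left + currentCount.right)

  override val evaluateExpression = Cast(currentCount, LongType)
}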
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala
index ecc0644164..0d32949775 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala
@@ -38,7 +38,7 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite {
}
def createBuffer(hll: HyperLogLogPlusPlus): MutableRow = {
- val buffer = new SpecificMutableRow(hll.bufferAttributes.map(_.dataType))
+ val buffer = new SpecificMutableRow(hll.aggBufferAttributes.map(_.dataType))
hll.initialize(buffer)
buffer
}