diff options
author | Josh Rosen <joshrosen@databricks.com> | 2015-10-07 13:19:49 -0700 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2015-10-07 13:19:49 -0700 |
commit | a9ecd06149df4ccafd3927c35f63b9f03f170ae5 (patch) | |
tree | 99d41c5b5f8cd786e2c161c339373a740b4048eb /sql/catalyst | |
parent | 5be5d247440d6346d667c4b3d817666126f62906 (diff) | |
download | spark-a9ecd06149df4ccafd3927c35f63b9f03f170ae5.tar.gz spark-a9ecd06149df4ccafd3927c35f63b9f03f170ae5.tar.bz2 spark-a9ecd06149df4ccafd3927c35f63b9f03f170ae5.zip |
[SPARK-10941] [SQL] Refactor AggregateFunction2 and AlgebraicAggregate interfaces to improve code clarity
This patch refactors several of the Aggregate2 interfaces in order to improve code clarity.
The biggest change is a refactoring of the `AggregateFunction2` class hierarchy. In the old code, we had a class named `AlgebraicAggregate` that inherited from `AggregateFunction2`, added a new set of methods, then banned the use of the inherited methods. I found this to be fairly confusing because.
If you look carefully at the existing code, you'll see that subclasses of `AggregateFunction2` fall into two disjoint categories: imperative aggregation functions which directly extended `AggregateFunction2` and declarative, expression-based aggregate functions which extended `AlgebraicAggregate`. In order to make this more explicit, this patch refactors things so that `AggregateFunction2` is a sealed abstract class with two subclasses, `ImperativeAggregateFunction` and `ExpressionAggregateFunction`. The superclass, `AggregateFunction2`, now only contains methods and fields that are common to both subclasses.
After making this change, I updated the various AggregationIterator classes to comply with this new naming scheme. I also performed several small renamings in the aggregate interfaces themselves in order to improve clarity and rewrote or expanded a number of comments.
Author: Josh Rosen <joshrosen@databricks.com>
Closes #8973 from JoshRosen/tungsten-agg-comments.
Diffstat (limited to 'sql/catalyst')
3 files changed, 188 insertions, 119 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index e7f8104d43..4ad2607a85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -case class Average(child: Expression) extends AlgebraicAggregate { +case class Average(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -57,7 +57,7 @@ case class Average(child: Expression) extends AlgebraicAggregate { private val currentSum = AttributeReference("currentSum", sumDataType)() private val currentCount = AttributeReference("currentCount", LongType)() - override val bufferAttributes = currentSum :: currentCount :: Nil + override val aggBufferAttributes = currentSum :: currentCount :: Nil override val initialValues = Seq( /* currentSum = */ Cast(Literal(0), sumDataType), @@ -88,7 +88,7 @@ case class Average(child: Expression) extends AlgebraicAggregate { } } -case class Count(child: Expression) extends AlgebraicAggregate { +case class Count(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil override def nullable: Boolean = false @@ -101,7 +101,7 @@ case class Count(child: Expression) extends AlgebraicAggregate { private val currentCount = AttributeReference("currentCount", LongType)() - override val bufferAttributes = currentCount :: Nil + override val aggBufferAttributes = currentCount :: Nil override val initialValues = Seq( /* currentCount = */ Literal(0L) @@ -118,7 +118,7 @@ case class Count(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = Cast(currentCount, LongType) } -case class First(child: Expression) extends AlgebraicAggregate { +case class First(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -135,7 +135,7 @@ case class First(child: Expression) extends AlgebraicAggregate { private val first = AttributeReference("first", child.dataType)() - override val bufferAttributes = first :: Nil + override val aggBufferAttributes = first :: Nil override val initialValues = Seq( /* first = */ Literal.create(null, child.dataType) @@ -152,7 +152,7 @@ case class First(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = first } -case class Last(child: Expression) extends AlgebraicAggregate { +case class Last(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -169,7 +169,7 @@ case class Last(child: Expression) extends AlgebraicAggregate { private val last = AttributeReference("last", child.dataType)() - override val bufferAttributes = last :: Nil + override val aggBufferAttributes = last :: Nil override val initialValues = Seq( /* last = */ Literal.create(null, child.dataType) @@ -186,7 +186,7 @@ case class Last(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = last } -case class Max(child: Expression) extends AlgebraicAggregate { +case class Max(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -200,7 +200,7 @@ case class Max(child: Expression) extends AlgebraicAggregate { private val max = AttributeReference("max", child.dataType)() - override val bufferAttributes = max :: Nil + override val aggBufferAttributes = max :: Nil override val initialValues = Seq( /* max = */ Literal.create(null, child.dataType) @@ -220,7 +220,7 @@ case class Max(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = max } -case class Min(child: Expression) extends AlgebraicAggregate { +case class Min(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -234,7 +234,7 @@ case class Min(child: Expression) extends AlgebraicAggregate { private val min = AttributeReference("min", child.dataType)() - override val bufferAttributes = min :: Nil + override val aggBufferAttributes = min :: Nil override val initialValues = Seq( /* min = */ Literal.create(null, child.dataType) @@ -277,7 +277,7 @@ case class StddevSamp(child: Expression) extends StddevAgg(child) { // Compute standard deviation based on online algorithm specified here: // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -abstract class StddevAgg(child: Expression) extends AlgebraicAggregate { +abstract class StddevAgg(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -304,7 +304,7 @@ abstract class StddevAgg(child: Expression) extends AlgebraicAggregate { private val currentAvg = AttributeReference("currentAvg", resultType)() private val currentMk = AttributeReference("currentMk", resultType)() - override val bufferAttributes = preCount :: currentCount :: preAvg :: + override val aggBufferAttributes = preCount :: currentCount :: preAvg :: currentAvg :: currentMk :: Nil override val initialValues = Seq( @@ -397,7 +397,7 @@ abstract class StddevAgg(child: Expression) extends AlgebraicAggregate { } } -case class Sum(child: Expression) extends AlgebraicAggregate { +case class Sum(child: Expression) extends ExpressionAggregate { override def children: Seq[Expression] = child :: Nil @@ -429,7 +429,7 @@ case class Sum(child: Expression) extends AlgebraicAggregate { private val zero = Cast(Literal(0), sumDataType) - override val bufferAttributes = currentSum :: Nil + override val aggBufferAttributes = currentSum :: Nil override val initialValues = Seq( /* currentSum = */ Literal.create(null, sumDataType) @@ -473,7 +473,7 @@ case class Sum(child: Expression) extends AlgebraicAggregate { */ // scalastyle:on case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) - extends AggregateFunction2 { + extends ImperativeAggregate { import HyperLogLogPlusPlus._ /** @@ -531,28 +531,26 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) */ private[this] val numWords = m / REGISTERS_PER_WORD + 1 - def children: Seq[Expression] = Seq(child) + override def children: Seq[Expression] = Seq(child) - def nullable: Boolean = false - - def dataType: DataType = LongType + override def nullable: Boolean = false - def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def dataType: DataType = LongType - def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes) + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - def cloneBufferAttributes: Seq[Attribute] = bufferAttributes.map(_.newInstance()) + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) /** Allocate enough words to store all registers. */ - val bufferAttributes: Seq[AttributeReference] = Seq.tabulate(numWords) { i => + override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(numWords) { i => AttributeReference(s"MS[$i]", LongType)() } /** Fill all words with zeros. */ - def initialize(buffer: MutableRow): Unit = { + override def initialize(buffer: MutableRow): Unit = { var word = 0 while (word < numWords) { - buffer.setLong(mutableBufferOffset + word, 0) + buffer.setLong(mutableAggBufferOffset + word, 0) word += 1 } } @@ -562,7 +560,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) * * Variable names in the HLL++ paper match variable names in the code. */ - def update(buffer: MutableRow, input: InternalRow): Unit = { + override def update(buffer: MutableRow, input: InternalRow): Unit = { val v = child.eval(input) if (v != null) { // Create the hashed value 'x'. @@ -576,7 +574,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) // Get the word containing the register we are interested in. val wordOffset = idx / REGISTERS_PER_WORD - val word = buffer.getLong(mutableBufferOffset + wordOffset) + val word = buffer.getLong(mutableAggBufferOffset + wordOffset) // Extract the M[J] register value from the word. val shift = REGISTER_SIZE * (idx - (wordOffset * REGISTERS_PER_WORD)) @@ -585,7 +583,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) // Assign the maximum number of leading zeros to the register. if (pw > Midx) { - buffer.setLong(mutableBufferOffset + wordOffset, (word & ~mask) | (pw << shift)) + buffer.setLong(mutableAggBufferOffset + wordOffset, (word & ~mask) | (pw << shift)) } } } @@ -594,12 +592,12 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) * Merge the HLL buffers by iterating through the registers in both buffers and select the * maximum number of leading zeros for each register. */ - def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { var idx = 0 var wordOffset = 0 while (wordOffset < numWords) { - val word1 = buffer1.getLong(mutableBufferOffset + wordOffset) - val word2 = buffer2.getLong(inputBufferOffset + wordOffset) + val word1 = buffer1.getLong(mutableAggBufferOffset + wordOffset) + val word2 = buffer2.getLong(inputAggBufferOffset + wordOffset) var word = 0L var i = 0 var mask = REGISTER_WORD_MASK @@ -609,7 +607,7 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) i += 1 idx += 1 } - buffer1.setLong(mutableBufferOffset + wordOffset, word) + buffer1.setLong(mutableAggBufferOffset + wordOffset, word) wordOffset += 1 } } @@ -664,14 +662,14 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) * * Variable names in the HLL++ paper match variable names in the code. */ - def eval(buffer: InternalRow): Any = { + override def eval(buffer: InternalRow): Any = { // Compute the inverse of indicator value 'z' and count the number of zeros 'V'. var zInverse = 0.0d var V = 0.0d var idx = 0 var wordOffset = 0 while (wordOffset < numWords) { - val word = buffer.getLong(mutableBufferOffset + wordOffset) + val word = buffer.getLong(mutableAggBufferOffset + wordOffset) var i = 0 var shift = 0 while (idx < m && i < REGISTERS_PER_WORD) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index d8699533cd..74e15ec90b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -69,9 +69,6 @@ private[sql] case object NoOp extends Expression with Unevaluable { /** * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. - * @param aggregateFunction - * @param mode - * @param isDistinct */ private[sql] case class AggregateExpression2( aggregateFunction: AggregateFunction2, @@ -86,7 +83,7 @@ private[sql] case class AggregateExpression2( override def references: AttributeSet = { val childReferences = mode match { case Partial | Complete => aggregateFunction.references.toSeq - case PartialMerge | Final => aggregateFunction.bufferAttributes + case PartialMerge | Final => aggregateFunction.aggBufferAttributes } AttributeSet(childReferences) @@ -95,98 +92,192 @@ private[sql] case class AggregateExpression2( override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" } -abstract class AggregateFunction2 - extends Expression with ImplicitCastInputTypes { +/** + * AggregateFunction2 is the superclass of two aggregation function interfaces: + * + * - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of + * initialize(), update(), and merge() functions that operate on Row-based aggregation buffers. + * - [[ExpressionAggregate]] is for aggregation functions that are specified using + * Catalyst expressions. + * + * In both interfaces, aggregates must define the schema ([[aggBufferSchema]]) and attributes + * ([[aggBufferAttributes]]) of an aggregation buffer which is used to hold partial aggregate + * results. At runtime, multiple aggregate functions are evaluated by the same operator using a + * combined aggregation buffer which concatenates the aggregation buffers of the individual + * aggregate functions. + * + * Code which accepts [[AggregateFunction2]] instances should be prepared to handle both types of + * aggregate functions. + */ +sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes { /** An aggregate function is not foldable. */ final override def foldable: Boolean = false + /** The schema of the aggregation buffer. */ + def aggBufferSchema: StructType + + /** Attributes of fields in aggBufferSchema. */ + def aggBufferAttributes: Seq[AttributeReference] + /** - * The offset of this function's start buffer value in the - * underlying shared mutable aggregation buffer. - * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share - * the same aggregation buffer. In this shared buffer, the position of the first - * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)` - * will be 2. + * Attributes of fields in input aggregation buffers (immutable aggregation buffers that are + * merged with mutable aggregation buffers in the merge() function or merge expressions). + * These attributes are created automatically by cloning the [[aggBufferAttributes]]. */ - protected var mutableBufferOffset: Int = 0 - - def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = { - mutableBufferOffset = newMutableBufferOffset - } + def inputAggBufferAttributes: Seq[AttributeReference] /** - * The offset of this function's start buffer value in the - * underlying shared input aggregation buffer. An input aggregation buffer is used - * when we merge two aggregation buffers and it is basically the immutable one - * (we merge an input aggregation buffer and a mutable aggregation buffer and - * then store the new buffer values to the mutable aggregation buffer). - * Usually, an input aggregation buffer also contain extra elements like grouping - * keys at the beginning. So, mutableBufferOffset and inputBufferOffset are often - * different. - * For example, we have a grouping expression `key``, and two aggregate functions - * `avg(x)` and `avg(y)`. In this shared input aggregation buffer, the position of the first - * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` - * will be 3 (position 0 is used for the value of key`). + * Indicates if this function supports partial aggregation. + * Currently Hive UDAF is the only one that doesn't support partial aggregation. */ - protected var inputBufferOffset: Int = 0 + def supportsPartial: Boolean = true - def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = { - inputBufferOffset = newInputBufferOffset - } + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") +} - /** The schema of the aggregation buffer. */ - def bufferSchema: StructType +/** + * API for aggregation functions that are expressed in terms of imperative initialize(), update(), + * and merge() functions which operate on Row-based aggregation buffers. + * + * Within these functions, code should access fields of the mutable aggregation buffer by adding the + * bufferSchema-relative field number to `mutableAggBufferOffset` then using this new field number + * to access the buffer Row. This is necessary because this aggregation function's buffer is + * embedded inside of a larger shared aggregation buffer when an aggregation operator evaluates + * multiple aggregate functions at the same time. + * + * We need to perform similar field number arithmetic when merging multiple intermediate + * aggregate buffers together in `merge()` (in this case, use `inputAggBufferOffset` when accessing + * the input buffer). + */ +abstract class ImperativeAggregate extends AggregateFunction2 { - /** Attributes of fields in bufferSchema. */ - def bufferAttributes: Seq[AttributeReference] + /** + * The offset of this function's first buffer value in the underlying shared mutable aggregation + * buffer. + * + * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share the same + * aggregation buffer. In this shared buffer, the position of the first buffer value of `avg(x)` + * will be 0 and the position of the first buffer value of `avg(y)` will be 2: + * + * avg(x) mutableAggBufferOffset = 0 + * | + * v + * +--------+--------+--------+--------+ + * | sum1 | count1 | sum2 | count2 | + * +--------+--------+--------+--------+ + * ^ + * | + * avg(y) mutableAggBufferOffset = 2 + * + */ + protected var mutableAggBufferOffset: Int = 0 - /** Clones bufferAttributes. */ - def cloneBufferAttributes: Seq[Attribute] + def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Unit = { + mutableAggBufferOffset = newMutableAggBufferOffset + } /** - * Initializes its aggregation buffer located in `buffer`. - * It will use bufferOffset to find the starting point of - * its buffer in the given `buffer` shared with other functions. + * The offset of this function's start buffer value in the underlying shared input aggregation + * buffer. An input aggregation buffer is used when we merge two aggregation buffers together in + * the `update()` function and is immutable (we merge an input aggregation buffer and a mutable + * aggregation buffer and then store the new buffer values to the mutable aggregation buffer). + * + * An input aggregation buffer may contain extra fields, such as grouping keys, at its start, so + * mutableAggBufferOffset and inputAggBufferOffset are often different. + * + * For example, say we have a grouping expression, `key`, and two aggregate functions, + * `avg(x)` and `avg(y)`. In the shared input aggregation buffer, the position of the first + * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` + * will be 3 (position 0 is used for the value of `key`): + * + * avg(x) inputAggBufferOffset = 1 + * | + * v + * +--------+--------+--------+--------+--------+ + * | key | sum1 | count1 | sum2 | count2 | + * +--------+--------+--------+--------+--------+ + * ^ + * | + * avg(y) inputAggBufferOffset = 3 + * */ - def initialize(buffer: MutableRow): Unit + protected var inputAggBufferOffset: Int = 0 + + def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Unit = { + inputAggBufferOffset = newInputAggBufferOffset + } /** - * Updates its aggregation buffer located in `buffer` based on the given `input`. - * It will use bufferOffset to find the starting point of its buffer in the given `buffer` - * shared with other functions. + * Initializes the mutable aggregation buffer located in `mutableAggBuffer`. + * + * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. */ - def update(buffer: MutableRow, input: InternalRow): Unit + def initialize(mutableAggBuffer: MutableRow): Unit /** - * Updates its aggregation buffer located in `buffer1` by combining intermediate results - * in the current buffer and intermediate results from another buffer `buffer2`. - * It will use bufferOffset to find the starting point of its buffer in the given `buffer1` - * and `buffer2`. + * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`. + * + * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. */ - def merge(buffer1: MutableRow, buffer2: InternalRow): Unit - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = - throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit /** - * Indicates if this function supports partial aggregation. - * Currently Hive UDAF is the only one that doesn't support partial aggregation. + * Combines new intermediate results from the `inputAggBuffer` with the existing intermediate + * results in the `mutableAggBuffer.` + * + * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. + * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. */ - def supportsPartial: Boolean = true + def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit + + final lazy val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) } /** - * A helper class for aggregate functions that can be implemented in terms of catalyst expressions. + * API for aggregation functions that are expressed in terms of Catalyst expressions. + * + * When implementing a new expression-based aggregate function, start by implementing + * `bufferAttributes`, defining attributes for the fields of the mutable aggregation buffer. You + * can then use these attributes when defining `updateExpressions`, `mergeExpressions`, and + * `evaluateExpressions`. */ -abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable with Unevaluable { +abstract class ExpressionAggregate + extends AggregateFunction2 + with Serializable + with Unevaluable { + /** + * Expressions for initializing empty aggregation buffers. + */ val initialValues: Seq[Expression] + + /** + * Expressions for updating the mutable aggregation buffer based on an input row. + */ val updateExpressions: Seq[Expression] + + /** + * A sequence of expressions for merging two aggregation buffers together. When defining these + * expressions, you can use the syntax `attributeName.left` and `attributeName.right` to refer + * to the attributes corresponding to each of the buffers being merged (this magic is enabled + * by the [[RichAttribute]] implicit class). + */ val mergeExpressions: Seq[Expression] + + /** + * An expression which returns the final value for this aggregate function. Its data type should + * match this expression's [[dataType]]. + */ val evaluateExpression: Expression - override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) + /** An expression-based aggregate's bufferSchema is derived from bufferAttributes. */ + final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + final lazy val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) /** * A helper class for representing an attribute used in merging two @@ -194,33 +285,13 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w * we merge buffer values and then update bufferLeft. A [[RichAttribute]] * of an [[AttributeReference]] `a` has two functions `left` and `right`, * which represent `a` in `bufferLeft` and `bufferRight`, respectively. - * @param a */ implicit class RichAttribute(a: AttributeReference) { /** Represents this attribute at the mutable buffer side. */ def left: AttributeReference = a /** Represents this attribute at the input buffer side (the data value is read-only). */ - def right: AttributeReference = cloneBufferAttributes(bufferAttributes.indexOf(a)) - } - - /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */ - override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes) - - override def initialize(buffer: MutableRow): Unit = { - throw new UnsupportedOperationException( - "AlgebraicAggregate's initialize should not be called directly") - } - - override final def update(buffer: MutableRow, input: InternalRow): Unit = { - throw new UnsupportedOperationException( - "AlgebraicAggregate's update should not be called directly") + def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a)) } - - override final def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - throw new UnsupportedOperationException( - "AlgebraicAggregate's merge should not be called directly") - } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala index ecc0644164..0d32949775 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala @@ -38,7 +38,7 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite { } def createBuffer(hll: HyperLogLogPlusPlus): MutableRow = { - val buffer = new SpecificMutableRow(hll.bufferAttributes.map(_.dataType)) + val buffer = new SpecificMutableRow(hll.aggBufferAttributes.map(_.dataType)) hll.initialize(buffer) buffer } |