author    Yin Huai <yhuai@databricks.com>    2015-07-28 19:01:25 -0700
committer Yin Huai <yhuai@databricks.com>    2015-07-28 19:01:25 -0700
commit    3744b7fd42e52011af60cc205fcb4e4b23b35c68 (patch)
tree      cd48bb28d354e3da5550f6b2be56a6dc618742c3 /sql
parent    e78ec1a8fabfe409c92c4904208f53dbdcfcf139 (diff)
[SPARK-9422] [SQL] Remove the placeholder attributes used in the aggregation buffers
https://issues.apache.org/jira/browse/SPARK-9422

Author: Yin Huai <yhuai@databricks.com>

Closes #7737 from yhuai/removePlaceHolder and squashes the following commits:

ec29b44 [Yin Huai] Remove placeholder attributes.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala |  27
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala        |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala        | 209
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala                      |  17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala                     |   4
5 files changed, 121 insertions(+), 140 deletions(-)
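
The placeholder attributes removed by this commit existed only to pad the shared aggregation buffer so that its layout matched the input row; after the change the buffer holds nothing but the aggregation values, and each function addresses its own slice through an explicit offset. A minimal standalone sketch of that offset bookkeeping (illustrative only, not Spark code; the Avg class and its members are invented for the example):

object OffsetSketch {
  // Each Avg keeps (sum, count) starting at `mutableBufferOffset` in the shared buffer.
  final class Avg(val mutableBufferOffset: Int) {
    def initialize(buffer: Array[Double]): Unit = {
      buffer(mutableBufferOffset) = 0.0     // sum
      buffer(mutableBufferOffset + 1) = 0.0 // count
    }
    def update(buffer: Array[Double], value: Double): Unit = {
      buffer(mutableBufferOffset) += value
      buffer(mutableBufferOffset + 1) += 1.0
    }
    def eval(buffer: Array[Double]): Double =
      buffer(mutableBufferOffset) / buffer(mutableBufferOffset + 1)
  }

  def main(args: Array[String]): Unit = {
    val avgX = new Avg(mutableBufferOffset = 0) // first buffer value of avg(x) at position 0
    val avgY = new Avg(mutableBufferOffset = 2) // first buffer value of avg(y) at position 2
    val buffer = new Array[Double](4)           // exactly the buffer values, no placeholder slots
    Seq(avgX, avgY).foreach(_.initialize(buffer))
    Seq((1.0, 10.0), (3.0, 30.0)).foreach { case (x, y) =>
      avgX.update(buffer, x)
      avgY.update(buffer, y)
    }
    println((avgX.eval(buffer), avgY.eval(buffer))) // prints (2.0,20.0)
  }
}
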
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 10bd19c8a8..9fb7623172 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -103,9 +103,30 @@ abstract class AggregateFunction2
final override def foldable: Boolean = false
/**
- * The offset of this function's buffer in the underlying buffer shared with other functions.
+ * The offset of this function's start buffer value in the
+ * underlying shared mutable aggregation buffer.
+ * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share
+ * the same aggregation buffer. In this shared buffer, the position of the first
+ * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)`
+ * will be 2.
*/
- var bufferOffset: Int = 0
+ var mutableBufferOffset: Int = 0
+
+ /**
+ * The offset of this function's start buffer value in the
+ * underlying shared input aggregation buffer. An input aggregation buffer is used
+ * when we merge two aggregation buffers and it is basically the immutable one
+ * (we merge an input aggregation buffer and a mutable aggregation buffer and
+ * then store the new buffer values to the mutable aggregation buffer).
+ * Usually, an input aggregation buffer also contains extra elements like grouping
+ * keys at the beginning. So, mutableBufferOffset and inputBufferOffset are often
+ * different.
+ * For example, we have a grouping expression `key`, and two aggregate functions
+ * `avg(x)` and `avg(y)`. In this shared input aggregation buffer, the position of the first
+ * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)`
+ * will be 3 (position 0 is used for the value of `key`).
+ */
+ var inputBufferOffset: Int = 0
/** The schema of the aggregation buffer. */
def bufferSchema: StructType
@@ -176,7 +197,7 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w
override def initialize(buffer: MutableRow): Unit = {
var i = 0
while (i < bufferAttributes.size) {
- buffer(i + bufferOffset) = initialValues(i).eval()
+ buffer(i + mutableBufferOffset) = initialValues(i).eval()
i += 1
}
}
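
The scaladoc added above distinguishes the two offsets with the `avg(x)`/`avg(y)` example: in the mutable buffer their first values sit at positions 0 and 2, while in the input buffer, whose position 0 holds the grouping key, they sit at positions 1 and 3. A minimal standalone sketch of a merge that reads at `inputBufferOffset` and writes at `mutableBufferOffset` (illustrative only; the Avg class is invented for the example):

object MergeOffsetSketch {
  final class Avg(val mutableBufferOffset: Int, val inputBufferOffset: Int) {
    def merge(mutableBuffer: Array[Double], inputBuffer: IndexedSeq[Double]): Unit = {
      mutableBuffer(mutableBufferOffset)     += inputBuffer(inputBufferOffset)     // sum
      mutableBuffer(mutableBufferOffset + 1) += inputBuffer(inputBufferOffset + 1) // count
    }
  }

  def main(args: Array[String]): Unit = {
    // Mutable buffer layout: |sum(x)|count(x)|sum(y)|count(y)|
    // Input buffer layout:   |key|sum(x)|count(x)|sum(y)|count(y)|
    val avgX = new Avg(mutableBufferOffset = 0, inputBufferOffset = 1)
    val avgY = new Avg(mutableBufferOffset = 2, inputBufferOffset = 3)
    val mutableBuffer = Array(4.0, 2.0, 40.0, 2.0)
    val inputBuffer   = Vector(7.0 /* grouping key */, 6.0, 3.0, 60.0, 3.0)
    avgX.merge(mutableBuffer, inputBuffer)
    avgY.merge(mutableBuffer, inputBuffer)
    println(mutableBuffer.mkString(", ")) // 10.0, 5.0, 100.0, 5.0
  }
}
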
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
index 0c9082897f..98538c462b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
@@ -72,8 +72,10 @@ case class Aggregate2Sort(
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
child.execute().mapPartitions { iter =>
if (aggregateExpressions.length == 0) {
- new GroupingIterator(
+ new FinalSortAggregationIterator(
groupingExpressions,
+ Nil,
+ Nil,
resultExpressions,
newMutableProjection,
child.output,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
index 1b89edafa8..2ca0cb82c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
@@ -41,7 +41,8 @@ private[sql] abstract class SortAggregationIterator(
///////////////////////////////////////////////////////////////////////////
protected val aggregateFunctions: Array[AggregateFunction2] = {
- var bufferOffset = initialBufferOffset
+ var mutableBufferOffset = 0
+ var inputBufferOffset: Int = initialInputBufferOffset
val functions = new Array[AggregateFunction2](aggregateExpressions.length)
var i = 0
while (i < aggregateExpressions.length) {
@@ -54,13 +55,18 @@ private[sql] abstract class SortAggregationIterator(
// function's children in the update method of this aggregate function.
// Those eval calls require BoundReferences to work.
BindReferences.bindReference(func, inputAttributes)
- case _ => func
+ case _ =>
+ // We only need to set inputBufferOffset for aggregate functions with mode
+ // PartialMerge and Final.
+ func.inputBufferOffset = inputBufferOffset
+ inputBufferOffset += func.bufferSchema.length
+ func
}
- // Set bufferOffset for this function. It is important that setting bufferOffset
- // happens after all potential bindReference operations because bindReference
- // will create a new instance of the function.
- funcWithBoundReferences.bufferOffset = bufferOffset
- bufferOffset += funcWithBoundReferences.bufferSchema.length
+ // Set mutableBufferOffset for this function. It is important that setting
+ // mutableBufferOffset happens after all potential bindReference operations
+ // because bindReference will create a new instance of the function.
+ funcWithBoundReferences.mutableBufferOffset = mutableBufferOffset
+ mutableBufferOffset += funcWithBoundReferences.bufferSchema.length
functions(i) = funcWithBoundReferences
i += 1
}
@@ -97,25 +103,24 @@ private[sql] abstract class SortAggregationIterator(
// The number of elements of the underlying buffer of this operator.
// All aggregate functions are sharing this underlying buffer and they find their
// buffer values through bufferOffset.
- var size = initialBufferOffset
- var i = 0
- while (i < aggregateFunctions.length) {
- size += aggregateFunctions(i).bufferSchema.length
- i += 1
- }
- new GenericMutableRow(size)
+ new GenericMutableRow(aggregateFunctions.map(_.bufferSchema.length).sum)
}
protected val joinedRow = new JoinedRow
- protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp)
-
// This projection is used to initialize buffer values for all AlgebraicAggregates.
protected val algebraicInitialProjection = {
- val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+ val initExpressions = aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.initialValues
case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
}
+
newMutableProjection(initExpressions, Nil)().target(buffer)
}
@@ -132,10 +137,6 @@ private[sql] abstract class SortAggregationIterator(
// Indicates if we have a new group of rows to process.
protected var hasNewGroup: Boolean = true
- ///////////////////////////////////////////////////////////////////////////
- // Private methods
- ///////////////////////////////////////////////////////////////////////////
-
/** Initializes buffer values for all aggregate functions. */
protected def initializeBuffer(): Unit = {
algebraicInitialProjection(EmptyRow)
@@ -160,6 +161,10 @@ private[sql] abstract class SortAggregationIterator(
}
}
+ ///////////////////////////////////////////////////////////////////////////
+ // Private methods
+ ///////////////////////////////////////////////////////////////////////////
+
/** Processes rows in the current group. It will stop when it finds a new group. */
private def processCurrentGroup(): Unit = {
currentGroupingKey = nextGroupingKey
@@ -218,10 +223,13 @@ private[sql] abstract class SortAggregationIterator(
// Methods that need to be implemented
///////////////////////////////////////////////////////////////////////////
- protected def initialBufferOffset: Int
+ /** The initial input buffer offset for `inputBufferOffset` of an [[AggregateFunction2]]. */
+ protected def initialInputBufferOffset: Int
+ /** The function used to process an input row. */
protected def processRow(row: InternalRow): Unit
+ /** The function used to generate the result row. */
protected def generateOutput(): InternalRow
///////////////////////////////////////////////////////////////////////////
@@ -232,37 +240,6 @@ private[sql] abstract class SortAggregationIterator(
}
/**
- * An iterator only used to group input rows according to values of `groupingExpressions`.
- * It assumes that input rows are already grouped by values of `groupingExpressions`.
- */
-class GroupingIterator(
- groupingExpressions: Seq[NamedExpression],
- resultExpressions: Seq[NamedExpression],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow])
- extends SortAggregationIterator(
- groupingExpressions,
- Nil,
- newMutableProjection,
- inputAttributes,
- inputIter) {
-
- private val resultProjection =
- newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))()
-
- override protected def initialBufferOffset: Int = 0
-
- override protected def processRow(row: InternalRow): Unit = {
- // Since we only do grouping, there is nothing to do at here.
- }
-
- override protected def generateOutput(): InternalRow = {
- resultProjection(currentGroupingKey)
- }
-}
-
-/**
* An iterator used to do partial aggregations (for those aggregate functions with mode Partial).
* It assumes that input rows are already grouped by values of `groupingExpressions`.
* The format of its output rows is:
@@ -291,7 +268,7 @@ class PartialSortAggregationIterator(
newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
}
- override protected def initialBufferOffset: Int = 0
+ override protected def initialInputBufferOffset: Int = 0
override protected def processRow(row: InternalRow): Unit = {
// Process all algebraic aggregate functions.
@@ -318,11 +295,7 @@ class PartialSortAggregationIterator(
* |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
*
* The format of its internal buffer is:
- * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
- * Every placeholder is for a grouping expression.
- * The actual buffers are stored after placeholderN.
- * The reason that we have placeholders at here is to make our underlying buffer have the same
- * length with a input row.
+ * |aggregationBuffer1|...|aggregationBufferN|
*
* The format of its output rows is:
* |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
@@ -340,33 +313,21 @@ class PartialMergeSortAggregationIterator(
inputAttributes,
inputIter) {
- private val placeholderAttributes =
- Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
-
// This projection is used to merge buffer values for all AlgebraicAggregates.
private val algebraicMergeProjection = {
- val bufferSchemata =
- placeholderAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
- placeholderAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
- val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+ val mergeInputSchema =
+ aggregateFunctions.flatMap(_.bufferAttributes) ++
+ groupingExpressions.map(_.toAttribute) ++
+ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val mergeExpressions = aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.mergeExpressions
case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
}
- newMutableProjection(mergeExpressions, bufferSchemata)()
+ newMutableProjection(mergeExpressions, mergeInputSchema)()
}
- // This projection is used to extract aggregation buffers from the underlying buffer.
- // We need it because the underlying buffer has placeholders at its beginning.
- private val extractsBufferValues = {
- val expressions = aggregateFunctions.flatMap {
- case agg => agg.bufferAttributes
- }
-
- newMutableProjection(expressions, inputAttributes)()
- }
-
- override protected def initialBufferOffset: Int = groupingExpressions.length
+ override protected def initialInputBufferOffset: Int = groupingExpressions.length
override protected def processRow(row: InternalRow): Unit = {
// Process all algebraic aggregate functions.
@@ -381,7 +342,7 @@ class PartialMergeSortAggregationIterator(
override protected def generateOutput(): InternalRow = {
// We output grouping expressions and aggregation buffers.
- joinedRow(currentGroupingKey, extractsBufferValues(buffer))
+ joinedRow(currentGroupingKey, buffer).copy()
}
}
@@ -393,11 +354,7 @@ class PartialMergeSortAggregationIterator(
* |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
*
* The format of its internal buffer is:
- * |placeholder1|...|placeholder N|aggregationBuffer1|...|aggregationBufferN|
- * Every placeholder is for a grouping expression.
- * The actual buffers are stored after placeholderN.
- * The reason that we have placeholders at here is to make our underlying buffer have the same
- * length with a input row.
+ * |aggregationBuffer1|...|aggregationBufferN|
*
* The format of its output rows is represented by the schema of `resultExpressions`.
*/
@@ -425,27 +382,23 @@ class FinalSortAggregationIterator(
newMutableProjection(
resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
- private val offsetAttributes =
- Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
-
// This projection is used to merge buffer values for all AlgebraicAggregates.
private val algebraicMergeProjection = {
- val bufferSchemata =
- offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
- offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
- val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+ val mergeInputSchema =
+ aggregateFunctions.flatMap(_.bufferAttributes) ++
+ groupingExpressions.map(_.toAttribute) ++
+ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val mergeExpressions = aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.mergeExpressions
case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
}
- newMutableProjection(mergeExpressions, bufferSchemata)()
+ newMutableProjection(mergeExpressions, mergeInputSchema)()
}
// This projection is used to evaluate all AlgebraicAggregates.
private val algebraicEvalProjection = {
- val bufferSchemata =
- offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
- offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes)
val evalExpressions = aggregateFunctions.map {
case ae: AlgebraicAggregate => ae.evaluateExpression
case agg: AggregateFunction2 => NoOp
@@ -454,7 +407,7 @@ class FinalSortAggregationIterator(
newMutableProjection(evalExpressions, bufferSchemata)()
}
- override protected def initialBufferOffset: Int = groupingExpressions.length
+ override protected def initialInputBufferOffset: Int = groupingExpressions.length
override def initialize(): Unit = {
if (inputIter.hasNext) {
@@ -471,7 +424,10 @@ class FinalSortAggregationIterator(
// Right now, the buffer only contains initial buffer values. Because
// merging two buffers with initial values will generate a row that
// still stores initial values. We set the currentRow as the copy of the current buffer.
- val currentRow = buffer.copy()
+ // Because the input aggregation buffer has initialInputBufferOffset extra values at the
+ // beginning, we create a dummy row for this part.
+ val currentRow =
+ joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy()
nextGroupingKey = groupGenerator(currentRow).copy()
firstRowInNextGroup = currentRow
} else {
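
In the hunk above, `groupGenerator` was compiled against the input-row layout, which carries `initialInputBufferOffset` leading values, so the empty-input path prepends a dummy row of that length before handing the buffer to it. A minimal standalone sketch of this alignment trick (illustrative only; names are invented for the example):

object DummyPrefixSketch {
  // A "projection" bound to absolute positions of the input-row layout
  // |key|bufferValue1|bufferValue2|: it reads the grouping key at position 0.
  def extractKey(row: IndexedSeq[Any]): Any = row(0)

  def main(args: Array[String]): Unit = {
    val initialInputBufferOffset = 1      // one leading value (the grouping key)
    val buffer = Vector[Any](0.0, 0L)     // internal buffer holds only initial values
    // The buffer alone has no leading key slot, so prepend a dummy prefix of the
    // right length before handing the row to a projection bound to the input layout.
    val currentRow = Vector.fill[Any](initialInputBufferOffset)(null) ++ buffer
    println(extractKey(currentRow))       // prints null: the dummy "key" slot
  }
}
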
@@ -518,18 +474,15 @@ class FinalSortAggregationIterator(
* Final mode.
*
* The format of its internal buffer is:
- * |placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)|
- * The first N placeholders represent slots of grouping expressions.
- * Then, next M placeholders represent slots of col1 to colM.
+ * |aggregationBuffer1|...|aggregationBuffer(N+M)|
* For aggregation buffers, first N aggregation buffers are used by N aggregate functions with
* mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode
- * Complete. The reason that we have placeholders at here is to make our underlying buffer
- * have the same length with a input row.
+ * Complete.
*
* The format of its output rows is represented by the schema of `resultExpressions`.
*/
class FinalAndCompleteSortAggregationIterator(
- override protected val initialBufferOffset: Int,
+ override protected val initialInputBufferOffset: Int,
groupingExpressions: Seq[NamedExpression],
finalAggregateExpressions: Seq[AggregateExpression2],
finalAggregateAttributes: Seq[Attribute],
@@ -561,9 +514,6 @@ class FinalAndCompleteSortAggregationIterator(
newMutableProjection(resultExpressions, inputSchema)()
}
- private val offsetAttributes =
- Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
-
// All aggregate functions with mode Final.
private val finalAggregateFunctions: Array[AggregateFunction2] = {
val functions = new Array[AggregateFunction2](finalAggregateExpressions.length)
@@ -601,38 +551,38 @@ class FinalAndCompleteSortAggregationIterator(
// This projection is used to merge buffer values for all AlgebraicAggregates with mode
// Final.
private val finalAlgebraicMergeProjection = {
- val numCompleteOffsetAttributes =
- completeAggregateFunctions.map(_.bufferAttributes.length).sum
- val completeOffsetAttributes =
- Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", NullType)())
- val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp)
-
- val bufferSchemata =
- offsetAttributes ++ finalAggregateFunctions.flatMap(_.bufferAttributes) ++
- completeOffsetAttributes ++ offsetAttributes ++
- finalAggregateFunctions.flatMap(_.cloneBufferAttributes) ++ completeOffsetAttributes
+ // The first initialInputBufferOffset values of the input aggregation buffer are
+ // for grouping expressions and distinct columns.
+ val groupingAttributesAndDistinctColumns = inputAttributes.take(initialInputBufferOffset)
+
+ val completeOffsetExpressions =
+ Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+
+ val mergeInputSchema =
+ finalAggregateFunctions.flatMap(_.bufferAttributes) ++
+ completeAggregateFunctions.flatMap(_.bufferAttributes) ++
+ groupingAttributesAndDistinctColumns ++
+ finalAggregateFunctions.flatMap(_.cloneBufferAttributes)
val mergeExpressions =
- placeholderExpressions ++ finalAggregateFunctions.flatMap {
+ finalAggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.mergeExpressions
case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
} ++ completeOffsetExpressions
-
- newMutableProjection(mergeExpressions, bufferSchemata)()
+ newMutableProjection(mergeExpressions, mergeInputSchema)()
}
// This projection is used to update buffer values for all AlgebraicAggregates with mode
// Complete.
private val completeAlgebraicUpdateProjection = {
- val numFinalOffsetAttributes = finalAggregateFunctions.map(_.bufferAttributes.length).sum
- val finalOffsetAttributes =
- Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", NullType)())
- val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp)
+ // We do not touch buffer values of aggregate functions with the Final mode.
+ val finalOffsetExpressions =
+ Seq.fill(finalAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
val bufferSchema =
- offsetAttributes ++ finalOffsetAttributes ++
+ finalAggregateFunctions.flatMap(_.bufferAttributes) ++
completeAggregateFunctions.flatMap(_.bufferAttributes)
val updateExpressions =
- placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
+ finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.updateExpressions
case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
}
@@ -641,9 +591,7 @@ class FinalAndCompleteSortAggregationIterator(
// This projection is used to evaluate all AlgebraicAggregates.
private val algebraicEvalProjection = {
- val bufferSchemata =
- offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
- offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes)
val evalExpressions = aggregateFunctions.map {
case ae: AlgebraicAggregate => ae.evaluateExpression
case agg: AggregateFunction2 => NoOp
@@ -667,7 +615,10 @@ class FinalAndCompleteSortAggregationIterator(
// Right now, the buffer only contains initial buffer values. Because
// merging two buffers with initial values will generate a row that
// still stores initial values. We set the currentRow as the copy of the current buffer.
- val currentRow = buffer.copy()
+ // Because the input aggregation buffer has initialInputBufferOffset extra values at the
+ // beginning, we create a dummy row for this part.
+ val currentRow =
+ joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy()
nextGroupingKey = groupGenerator(currentRow).copy()
firstRowInNextGroup = currentRow
} else {
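
Across this file the merge projections now take a `mergeInputSchema` laid out as the mutable buffer attributes, then the grouping attributes, then the `cloneBufferAttributes`, which presumably mirrors the flat row the projection sees at merge time: the mutable buffer joined with the incoming row (grouping key followed by the incoming buffer). A minimal standalone sketch of that joined-row indexing (illustrative only; the Joined class is invented for the example):

object JoinedRowSketch {
  // A cheap "JoinedRow": view two rows as one without copying.
  final class Joined(left: IndexedSeq[Any], right: IndexedSeq[Any]) {
    def apply(i: Int): Any = if (i < left.length) left(i) else right(i - left.length)
    def length: Int = left.length + right.length
  }

  def main(args: Array[String]): Unit = {
    val mutableBuffer = Vector[Any](4.0, 2L)       // |sum(x)|count(x)|
    val incomingRow   = Vector[Any]("k1", 6.0, 3L) // |key|sum(x)|count(x)|
    val joined = new Joined(mutableBuffer, incomingRow)
    // Flat positions: 0 -> sum(x), 1 -> count(x), 2 -> key, 3 -> incoming sum(x), 4 -> incoming count(x)
    val mergedSum   = joined(0).asInstanceOf[Double] + joined(3).asInstanceOf[Double]
    val mergedCount = joined(1).asInstanceOf[Long]   + joined(4).asInstanceOf[Long]
    println((mergedSum, mergedCount)) // (10.0,5)
  }
}
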
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 073c45ae2f..cc54319171 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -184,7 +184,7 @@ private[sql] case class ScalaUDAF(
bufferSchema,
bufferValuesToCatalystConverters,
bufferValuesToScalaConverters,
- bufferOffset,
+ inputBufferOffset,
null)
lazy val mutableAggregateBuffer: MutableAggregationBufferImpl =
@@ -192,9 +192,16 @@ private[sql] case class ScalaUDAF(
bufferSchema,
bufferValuesToCatalystConverters,
bufferValuesToScalaConverters,
- bufferOffset,
+ mutableBufferOffset,
null)
+ lazy val evalAggregateBuffer: InputAggregationBuffer =
+ new InputAggregationBuffer(
+ bufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ mutableBufferOffset,
+ null)
override def initialize(buffer: MutableRow): Unit = {
mutableAggregateBuffer.underlyingBuffer = buffer
@@ -217,10 +224,10 @@ private[sql] case class ScalaUDAF(
udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
}
- override def eval(buffer: InternalRow = null): Any = {
- inputAggregateBuffer.underlyingInputBuffer = buffer
+ override def eval(buffer: InternalRow): Any = {
+ evalAggregateBuffer.underlyingInputBuffer = buffer
- udaf.evaluate(inputAggregateBuffer)
+ udaf.evaluate(evalAggregateBuffer)
}
override def toString: String = {
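
The udaf.scala change splits the single buffer wrapper into two views: `inputAggregateBuffer`, rebased at `inputBufferOffset` for merging an incoming row, and the new `evalAggregateBuffer`, rebased at `mutableBufferOffset` for evaluating the operator's own buffer. A minimal standalone sketch of such offset-rebased views (illustrative only; BufferView is invented for the example):

object UdafBufferViewSketch {
  // A read-only view that re-bases field access at a fixed offset into an underlying row.
  final class BufferView(val underlying: Array[Any], val offset: Int) {
    def get(i: Int): Any = underlying(offset + i)
  }

  def main(args: Array[String]): Unit = {
    // Merge: the incoming row starts with the grouping key, so the view of the
    // other side's buffer is re-based at inputBufferOffset = 1.
    val mergeInputRow = Array[Any]("key", 6.0, 3L)
    val inputView = new BufferView(mergeInputRow, offset = 1)
    // Eval: the row is the operator's own mutable buffer, which has no leading
    // key, so the view is re-based at mutableBufferOffset = 0.
    val ownBuffer = Array[Any](10.0, 5L)
    val evalView = new BufferView(ownBuffer, offset = 0)
    println((inputView.get(0), evalView.get(0))) // (6.0,10.0)
  }
}
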
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 5bbe6c162f..6549c87752 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -292,8 +292,8 @@ object Utils {
AggregateExpression2(aggregateFunction, PartialMerge, false)
}
val partialMergeAggregateAttributes =
- partialMergeAggregateExpressions.map {
- expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+ partialMergeAggregateExpressions.flatMap { agg =>
+ agg.aggregateFunction.bufferAttributes
}
val partialMergeAggregate =
Aggregate2Sort(