From 970ab8f6ddc66401ad1cf4b2d1050dd0c8876224 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 26 Aug 2016 10:56:57 -0700
Subject: [SPARK-17187][SQL][FOLLOW-UP] improve document of
 TypedImperativeAggregate

## What changes were proposed in this pull request?

Improve the document to make it easier to understand, and also mention the window operator.

## How was this patch tested?

N/A

Author: Wenchen Fan

Closes #14822 from cloud-fan/object-agg.
---
 .../expressions/aggregate/interfaces.scala | 101 +++++++++++++--------
 1 file changed, 61 insertions(+), 40 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index ecbaa2f466..b5c0844fbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -390,48 +390,69 @@ abstract class DeclarativeAggregate
   }
 }
 
+
 /**
  * Aggregation function which allows **arbitrary** user-defined Java objects to be used as the internal
- * aggregation buffer object.
+ * aggregation buffer.
  *
  * {{{
- *       aggregation buffer for normal aggregation function `avg`
- *            |
- *            v
- *       +--------------+---------------+-----------------------------------+
- *       |  sum1 (Long) | count1 (Long) | generic user-defined java objects |
- *       +--------------+---------------+-----------------------------------+
- *                                                    ^
- *                                                    |
- *          Aggregation buffer object for `TypedImperativeAggregate` aggregation function
+ *       aggregation buffer for normal aggregation function `avg`            aggregate buffer for `sum`
+ *            |                                                                           |
+ *            v                                                                           v
+ *       +--------------+---------------+-----------------------------------+-------------+
+ *       |  sum1 (Long) | count1 (Long) | generic user-defined java objects | sum2 (Long) |
+ *       +--------------+---------------+-----------------------------------+-------------+
+ *                                                    ^
+ *                                                    |
+ *          aggregation buffer object for `TypedImperativeAggregate` aggregation function
  * }}}
  *
- * Work flow (Partial mode aggregate at Mapper side, and Final mode aggregate at Reducer side):
+ * General workflow:
+ *
+ * Stage 1: initialize the aggregate buffer object.
+ *
+ *    1. The framework calls `initialize(buffer: MutableRow)` to set up the empty aggregate buffer.
+ *    2. In `initialize`, we call `createAggregationBuffer(): T` to get the initial buffer object,
+ *       and store it in the global buffer row.
+ *
+ *
+ * Stage 2: process input rows.
  *
- * Stage 1: Partial aggregate at Mapper side:
+ * If the aggregate mode is `Partial` or `Complete`:
+ *    1. The framework calls `update(buffer: MutableRow, input: InternalRow)` to process the input
+ *       row.
+ *    2. In `update`, we get the buffer object from the global buffer row and call
+ *       `update(buffer: T, input: InternalRow): Unit`.
  *
- *  1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
- *     buffer object.
- *  2. Upon each input row, the framework calls
- *     `update(buffer: T, input: InternalRow): Unit` to update the aggregation buffer object T.
- *  3. After processing all rows of current group (group by key), the framework will serialize
- *     aggregation buffer object T to storage format (Array[Byte]) and persist the Array[Byte]
- *     to disk if needed.
- *  4. The framework moves on to next group, until all groups have been processed.
+ * If the aggregate mode is `PartialMerge` or `Final`:
+ *    1. The framework calls `merge(buffer: MutableRow, inputBuffer: InternalRow)` to process the
+ *       input rows, which are serialized buffer objects shuffled from other nodes.
+ *    2. In `merge`, we get the buffer object from the global buffer row, get the binary data from
+ *       the input row and deserialize it to a buffer object, then call
+ *       `merge(buffer: T, input: T): Unit` to merge the two buffer objects.
  *
- * Shuffling exchange data to Reducer tasks...
  *
- * Stage 2: Final mode aggregate at Reducer side:
+ * Stage 3: output results.
+ *
+ * If the aggregate mode is `Partial` or `PartialMerge`:
+ *    1. The framework calls `serializeAggregateBufferInPlace` to replace the buffer object in the
+ *       global buffer row with binary data.
+ *    2. In `serializeAggregateBufferInPlace`, we get the buffer object from the global buffer row
+ *       and call `serialize(buffer: T): Array[Byte]` to serialize the buffer object to binary.
+ *    3. The framework outputs the buffer attributes and shuffles them to other nodes.
+ *
+ * If the aggregate mode is `Final` or `Complete`:
+ *    1. The framework calls `eval(buffer: InternalRow)` to calculate the final result.
+ *    2. In `eval`, we get the buffer object from the global buffer row and call
+ *       `eval(buffer: T): Any` to get the final result.
+ *    3. The framework outputs these final results.
+ *
+ *
+ * Window function workflow:
+ *    The framework calls `update(buffer: MutableRow, input: InternalRow)` several times and then
+ *    calls `eval(buffer: InternalRow)`, so there is no need for the window operator to call
+ *    `serializeAggregateBufferInPlace`.
  *
- *  1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
- *     buffer object (type T) for merging.
- *  2. For each aggregation output of Stage 1, The framework de-serializes the storage
- *     format (Array[Byte]) and produces one input aggregation object (type T).
- *  3. For each input aggregation object, the framework calls `merge(buffer: T, input: T): Unit`
- *     to merge the input aggregation object into aggregation buffer object.
- *  4. After processing all input aggregation objects of current group (group by key), the framework
- *     calls method `eval(buffer: T)` to generate the final output for this group.
- *  5. The framework moves on to next group, until all groups have been processed.
  *
  * NOTE: SQL with TypedImperativeAggregate functions is planned in sort based aggregation,
  * instead of hash based aggregation, as TypedImperativeAggregate uses BinaryType as aggregation
@@ -489,25 +510,23 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
   }
 
   final override def update(buffer: MutableRow, input: InternalRow): Unit = {
-    val bufferObject = getField[T](buffer, mutableAggBufferOffset)
-    update(bufferObject, input)
+    update(getBufferObject(buffer), input)
   }
 
   final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
-    val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+    val bufferObject = getBufferObject(buffer)
     // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate
     val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
     merge(bufferObject, inputObject)
   }
 
   final override def eval(buffer: InternalRow): Any = {
-    val bufferObject = getField[T](buffer, mutableAggBufferOffset)
-    eval(bufferObject)
+    eval(getBufferObject(buffer))
   }
 
   private[this] val anyObjectType = ObjectType(classOf[AnyRef])
-  private def getField[U](input: InternalRow, fieldIndex: Int): U = {
-    input.get(fieldIndex, anyObjectType).asInstanceOf[U]
+  private def getBufferObject(bufferRow: InternalRow): T = {
+    bufferRow.get(mutableAggBufferOffset, anyObjectType).asInstanceOf[T]
   }
 
   final override lazy val aggBufferAttributes: Seq[AttributeReference] = {
@@ -524,9 +543,11 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
    * In-place replaces the aggregation buffer object stored at buffer's index
    * `mutableAggBufferOffset`, with SparkSQL internally supported underlying storage format
    * (BinaryType).
+   *
+   * This is only called when doing Partial or PartialMerge mode aggregation, before the framework
+   * shuffles out aggregate buffers.
    */
   final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
-    val bufferObject = getField[T](buffer, mutableAggBufferOffset)
-    buffer(mutableAggBufferOffset) = serialize(bufferObject)
+    buffer(mutableAggBufferOffset) = serialize(getBufferObject(buffer))
   }
 }
--
cgit v1.2.3
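
To make the staged workflow in the new Scaladoc concrete, here is a minimal sketch of a `TypedImperativeAggregate` subclass, loosely modeled on the `TypedMax` test aggregate in Spark's `TypedImperativeAggregateSuite`. The class name, the `Array[Long]` buffer object, and the `ByteBuffer`-based storage format are illustrative assumptions, not part of this patch:

```scala
import java.nio.ByteBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.types._

// Tracks the maximum Long seen so far, using a mutable one-element
// Array[Long] as the "generic user-defined java object" buffer.
case class TypedMax(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends TypedImperativeAggregate[Array[Long]] {

  // Stage 1: the initial buffer object that `initialize` stores in the
  // global buffer row.
  override def createAggregationBuffer(): Array[Long] = Array(Long.MinValue)

  // Stage 2, Partial/Complete mode: fold one input row into the buffer object.
  override def update(buffer: Array[Long], input: InternalRow): Unit = {
    val v = child.eval(input)
    if (v != null) {
      buffer(0) = math.max(buffer(0), v.asInstanceOf[Long])
    }
  }

  // Stage 2, PartialMerge/Final mode: merge a deserialized buffer into ours.
  override def merge(buffer: Array[Long], input: Array[Long]): Unit = {
    buffer(0) = math.max(buffer(0), input(0))
  }

  // Stage 3, Final/Complete mode: produce the final result from the buffer.
  override def eval(buffer: Array[Long]): Any = buffer(0)

  // Storage format used when the framework shuffles buffers between nodes.
  override def serialize(buffer: Array[Long]): Array[Byte] =
    ByteBuffer.allocate(8).putLong(buffer(0)).array()

  override def deserialize(storageFormat: Array[Byte]): Array[Long] =
    Array(ByteBuffer.wrap(storageFormat).getLong)

  // Plumbing required by ImperativeAggregate / Expression.
  override def children: Seq[Expression] = child :: Nil
  override def inputTypes: Seq[AbstractDataType] = Seq(LongType)
  override def nullable: Boolean = false
  override def dataType: DataType = LongType
  override def withNewMutableAggBufferOffset(newOffset: Int): TypedMax =
    copy(mutableAggBufferOffset = newOffset)
  override def withNewInputAggBufferOffset(newOffset: Int): TypedMax =
    copy(inputAggBufferOffset = newOffset)
}
```

Because the buffer object is an arbitrary mutable Java object rather than a fixed-width SQL value, plans containing such functions fall back to sort-based aggregation, as the NOTE in the Scaladoc explains.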
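
And here is a sketch of the call sequence the Scaladoc describes, driven by hand against the hypothetical `TypedMax` above. In a real query the framework, not user code, performs these calls; `GenericMutableRow` is assumed here as the `MutableRow` implementation available in `org.apache.spark.sql.catalyst.expressions` at the time of this commit:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericMutableRow}
import org.apache.spark.sql.types.LongType

object TypedMaxDemo extends App {
  // The aggregate reads column 0 of each input row.
  val agg = TypedMax(BoundReference(0, LongType, nullable = false))
  val buffer = new GenericMutableRow(1)

  // Stage 1: `initialize` calls `createAggregationBuffer()` and stores the
  // resulting object in the buffer row at `mutableAggBufferOffset`.
  agg.initialize(buffer)

  // Stage 2 (Partial/Complete): one `update` call per input row.
  Seq(3L, 7L, 5L).foreach(v => agg.update(buffer, InternalRow(v)))

  // Stage 3 (Final/Complete): `eval` unwraps the buffer object and finishes.
  println(agg.eval(buffer)) // 7

  // In Partial/PartialMerge mode the framework would instead call
  // `serializeAggregateBufferInPlace(buffer)` here, replacing the buffer
  // object with its BinaryType form before shuffling it to other nodes.
  // A window operator skips that step: it only ever calls update and eval.
}
```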