author     Wenchen Fan <wenchen@databricks.com>    2016-08-26 10:56:57 -0700
committer  Yin Huai <yhuai@databricks.com>         2016-08-26 10:56:57 -0700
commit     970ab8f6ddc66401ad1cf4b2d1050dd0c8876224 (patch)
tree       16a44d484f25041fda36d9b462f7e58ad62b5000 /sql/catalyst/src/main/scala/org/apache
parent     28ab17922a227e8d93654d3478c0d493bfb599d5 (diff)
[SPARK-17187][SQL][FOLLOW-UP] improve document of TypedImperativeAggregate
## What changes were proposed in this pull request?

Improve the document to make it easier to understand, and also mention the window operator.

## How was this patch tested?

N/A

Author: Wenchen Fan <wenchen@databricks.com>

Closes #14822 from cloud-fan/object-agg.
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala | 101
1 file changed, 61 insertions(+), 40 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index ecbaa2f466..b5c0844fbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -390,48 +390,69 @@ abstract class DeclarativeAggregate
}
}
+
/**
* Aggregation function which allows an **arbitrary** user-defined Java object to be used as the internal
- * aggregation buffer object.
+ * aggregation buffer.
*
* {{{
- * aggregation buffer for normal aggregation function `avg`
- * |
- * v
- * +--------------+---------------+-----------------------------------+
- * | sum1 (Long) | count1 (Long) | generic user-defined java objects |
- * +--------------+---------------+-----------------------------------+
- * ^
- * |
- * Aggregation buffer object for `TypedImperativeAggregate` aggregation function
+ * aggregation buffer for normal aggregation function `avg` aggregate buffer for `sum`
+ * | |
+ * v v
+ * +--------------+---------------+-----------------------------------+-------------+
+ * | sum1 (Long) | count1 (Long) | generic user-defined java objects | sum2 (Long) |
+ * +--------------+---------------+-----------------------------------+-------------+
+ * ^
+ * |
+ * aggregation buffer object for `TypedImperativeAggregate` aggregation function
* }}}
*
- * Work flow (Partial mode aggregate at Mapper side, and Final mode aggregate at Reducer side):
+ * General workflow:
+ *
+ * Stage 1: initialize aggregate buffer object.
+ *
+ * 1. The framework calls `initialize(buffer: MutableRow)` to set up the empty aggregate buffer.
+ * 2. In `initialize`, we call `createAggregationBuffer(): T` to get the initial buffer object,
+ * and store it in the global buffer row.
+ *
+ *
+ * Stage 2: process input rows.
*
- * Stage 1: Partial aggregate at Mapper side:
+ * If the aggregate mode is `Partial` or `Complete`:
+ * 1. The framework calls `update(buffer: MutableRow, input: InternalRow)` to process the input
+ * row.
+ * 2. In `update`, we get the buffer object from the global buffer row and call
+ * `update(buffer: T, input: InternalRow): Unit`.
*
- * 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
- * buffer object.
- * 2. Upon each input row, the framework calls
- * `update(buffer: T, input: InternalRow): Unit` to update the aggregation buffer object T.
- * 3. After processing all rows of current group (group by key), the framework will serialize
- * aggregation buffer object T to storage format (Array[Byte]) and persist the Array[Byte]
- * to disk if needed.
- * 4. The framework moves on to next group, until all groups have been processed.
+ * If the aggregate mode is `PartialMerge` or `Final`:
+ * 1. The framework calls `merge(buffer: MutableRow, inputBuffer: InternalRow)` to process the
+ * input row, which is a serialized buffer object shuffled from another node.
+ * 2. In `merge`, we get the buffer object from the global buffer row, get the binary data
+ * from the input row and deserialize it to a buffer object, then call
+ * `merge(buffer: T, input: T): Unit` to merge the two buffer objects.
*
- * Shuffling exchange data to Reducer tasks...
*
- * Stage 2: Final mode aggregate at Reducer side:
+ * Stage 3: output results.
+ *
+ * If the aggregate mode is `Partial` or `PartialMerge`:
+ * 1. The framework calls `serializeAggregateBufferInPlace` to replace the buffer object in the
+ * global buffer row with binary data.
+ * 2. In `serializeAggregateBufferInPlace`, we get the buffer object from the global buffer row
+ * and call `serialize(buffer: T): Array[Byte]` to serialize the buffer object to binary.
+ * 3. The framework outputs buffer attributes and shuffles them to other nodes.
+ *
+ * If the aggregate mode is `Final` or `Complete`:
+ * 1. The framework calls `eval(buffer: InternalRow)` to calculate the final result.
+ * 2. In `eval`, we get the buffer object from the global buffer row and call
+ * `eval(buffer: T): Any` to get the final result.
+ * 3. The framework outputs these final results.
+ *
+ *
+ * Window function workflow:
+ * The framework calls `update(buffer: MutableRow, input: InternalRow)` several times and then
+ * calls `eval(buffer: InternalRow)`, so there is no need for the window operator to call
+ * `serializeAggregateBufferInPlace`.
*
- * 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
- * buffer object (type T) for merging.
- * 2. For each aggregation output of Stage 1, The framework de-serializes the storage
- * format (Array[Byte]) and produces one input aggregation object (type T).
- * 3. For each input aggregation object, the framework calls `merge(buffer: T, input: T): Unit`
- * to merge the input aggregation object into aggregation buffer object.
- * 4. After processing all input aggregation objects of current group (group by key), the framework
- * calls method `eval(buffer: T)` to generate the final output for this group.
- * 5. The framework moves on to next group, until all groups have been processed.
*
* NOTE: SQL with TypedImperativeAggregate functions is planned in sort-based aggregation,
* instead of hash-based aggregation, as TypedImperativeAggregate uses BinaryType as aggregation
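
To make the staged workflow documented above concrete, here is a minimal sketch (not part of
this patch) of what a TypedImperativeAggregate subclass might look like: a hypothetical
`LongDistinctCount` that counts distinct Long values of `child`, keeping a java.util.HashSet
as the buffer object and using plain Java serialization as the BinaryType storage format. The
class name and the serialization choice are illustrative assumptions, not Spark APIs.

{{{
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.{HashSet => JHashSet}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.types.{AbstractDataType, DataType, LongType}

// Hypothetical aggregate: counts distinct Long values of `child`. The HashSet is
// the "generic user-defined java objects" slot in the buffer diagram above.
case class LongDistinctCount(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends TypedImperativeAggregate[JHashSet[java.lang.Long]] {

  // Stage 1: called from initialize() to create the initial buffer object.
  override def createAggregationBuffer(): JHashSet[java.lang.Long] =
    new JHashSet[java.lang.Long]()

  // Stage 2, Partial/Complete mode: fold one input row into the buffer object.
  override def update(buffer: JHashSet[java.lang.Long], input: InternalRow): Unit = {
    val value = child.eval(input)
    if (value != null) buffer.add(value.asInstanceOf[java.lang.Long])
  }

  // Stage 2, PartialMerge/Final mode: merge a deserialized buffer from another task.
  override def merge(buffer: JHashSet[java.lang.Long], input: JHashSet[java.lang.Long]): Unit =
    buffer.addAll(input)

  // Stage 3, Final/Complete mode: produce the result from the buffer object.
  override def eval(buffer: JHashSet[java.lang.Long]): Any = buffer.size().toLong

  // BinaryType storage format used before buffers are shuffled; Java
  // serialization is just the simplest illustrative choice here.
  override def serialize(buffer: JHashSet[java.lang.Long]): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(buffer)
    out.close()
    bytes.toByteArray
  }

  override def deserialize(storageFormat: Array[Byte]): JHashSet[java.lang.Long] = {
    val in = new ObjectInputStream(new ByteArrayInputStream(storageFormat))
    in.readObject().asInstanceOf[JHashSet[java.lang.Long]]
  }

  // Standard ImperativeAggregate plumbing.
  override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newOffset)
  override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newOffset)
  override def children: Seq[Expression] = child :: Nil
  override def nullable: Boolean = false
  override def dataType: DataType = LongType
  override def inputTypes: Seq[AbstractDataType] = Seq(LongType)
}
}}}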
@@ -489,25 +510,23 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
}
final override def update(buffer: MutableRow, input: InternalRow): Unit = {
- val bufferObject = getField[T](buffer, mutableAggBufferOffset)
- update(bufferObject, input)
+ update(getBufferObject(buffer), input)
}
final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
- val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+ val bufferObject = getBufferObject(buffer)
// The inputBuffer stores the serialized aggregation buffer object produced by the partial aggregate
val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
merge(bufferObject, inputObject)
}
final override def eval(buffer: InternalRow): Any = {
- val bufferObject = getField[T](buffer, mutableAggBufferOffset)
- eval(bufferObject)
+ eval(getBufferObject(buffer))
}
private[this] val anyObjectType = ObjectType(classOf[AnyRef])
- private def getField[U](input: InternalRow, fieldIndex: Int): U = {
- input.get(fieldIndex, anyObjectType).asInstanceOf[U]
+ private def getBufferObject(bufferRow: InternalRow): T = {
+ bufferRow.get(mutableAggBufferOffset, anyObjectType).asInstanceOf[T]
}
final override lazy val aggBufferAttributes: Seq[AttributeReference] = {
@@ -524,9 +543,11 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
* Replaces, in place, the aggregation buffer object stored at the buffer's index
* `mutableAggBufferOffset` with Spark SQL's internally supported underlying storage format
* (BinaryType).
+ *
+ * This is only called when doing Partial or PartialMerge mode aggregation, before the framework
+ * shuffles out aggregate buffers.
*/
final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
- val bufferObject = getField[T](buffer, mutableAggBufferOffset)
- buffer(mutableAggBufferOffset) = serialize(bufferObject)
+ buffer(mutableAggBufferOffset) = serialize(getBufferObject(buffer))
}
}
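
Assuming the hypothetical `LongDistinctCount` sketch above, the lifecycle these final methods
drive can be exercised directly through the typed API, mimicking a Partial-side task whose
serialized buffer is shuffled and then merged on the Final side:

{{{
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.LongType

val agg = LongDistinctCount(BoundReference(0, LongType, nullable = true))

// Partial side: create a buffer object, fold in input rows, serialize for shuffle.
val partialBuffer = agg.createAggregationBuffer()
agg.update(partialBuffer, InternalRow(1L))
agg.update(partialBuffer, InternalRow(2L))
agg.update(partialBuffer, InternalRow(2L))
val shuffled: Array[Byte] = agg.serialize(partialBuffer)

// Final side: deserialize the shuffled bytes and merge into a fresh buffer.
val finalBuffer = agg.createAggregationBuffer()
agg.merge(finalBuffer, agg.deserialize(shuffled))

assert(agg.eval(finalBuffer) == 2L) // two distinct values: 1 and 2
}}}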