path: root/sql/catalyst
author	Sean Zhong <seanzhong@databricks.com>	2016-08-25 16:36:16 -0700
committer	Yin Huai <yhuai@databricks.com>	2016-08-25 16:36:16 -0700
commit	d96d1515638da20b594f7bfe3cfdb50088f25a04 (patch)
tree	69e7803b4f49d0ed03073795843eb95d8f63529f /sql/catalyst
parent	9b5a1d1d53bc4412de3cbc86dc819b0c213229a8 (diff)
[SPARK-17187][SQL] Supports using arbitrary Java object as internal aggregation buffer object
## What changes were proposed in this pull request?

This PR introduces an abstract class `TypedImperativeAggregate` so that an aggregation function of `TypedImperativeAggregate` can use an **arbitrary** user-defined Java object as its intermediate aggregation buffer object.

**This has advantages like:**
1. It supports a larger class of aggregation functions. For example, it becomes much easier to implement an aggregation function such as `percentile_approx`, which has a complex aggregation buffer definition.
2. It avoids serialization/de-serialization on every call of `update` or `merge` when converting a domain-specific aggregation object to the internal Spark SQL storage format.
3. It is easier to integrate with existing monoid libraries such as Algebird, and supports more aggregation functions with high performance.

Please see `org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMaxAggregate` for an example of how to define a `TypedImperativeAggregate` aggregation function. Please see the Java doc of `TypedImperativeAggregate` and the JIRA ticket SPARK-17187 for more information.

## How was this patch tested?

Unit tests.

Author: Sean Zhong <seanzhong@databricks.com>
Author: Yin Huai <yhuai@databricks.com>

Closes #14753 from clockfly/object_aggregation_buffer_try_2.
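For illustration, a minimal sketch of a `TypedImperativeAggregate` implementation is given below. It is loosely modeled on the `TypedMaxAggregate` example referenced above, not the exact test-suite code; the `TypedMax` and `MaxValue` names, the single `LongType` child column, and the byte-level storage format are assumptions made for this sketch.

```scala
import java.nio.ByteBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.types._

// Mutable buffer object; any Java object works, this one just tracks a running maximum.
class MaxValue(var value: Long, var isEmpty: Boolean = true)

// Hypothetical aggregate over a single LongType column, sketched for illustration only.
case class TypedMax(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[MaxValue] {

  override def createAggregationBuffer(): MaxValue = new MaxValue(Long.MinValue)

  // Fold one input row into the buffer (Partial/Complete mode).
  override def update(buffer: MaxValue, input: InternalRow): Unit = {
    val v = child.eval(input)
    if (v != null) {
      val longValue = v.asInstanceOf[Long]
      if (buffer.isEmpty || longValue > buffer.value) {
        buffer.value = longValue
        buffer.isEmpty = false
      }
    }
  }

  // Merge a de-serialized partial result into the buffer (PartialMerge/Final mode).
  override def merge(buffer: MaxValue, input: MaxValue): Unit = {
    if (!input.isEmpty && (buffer.isEmpty || input.value > buffer.value)) {
      buffer.value = input.value
      buffer.isEmpty = false
    }
  }

  override def eval(buffer: MaxValue): Any = if (buffer.isEmpty) null else buffer.value

  // Storage format assumed here: an 8-byte value followed by a 1-byte "isEmpty" flag.
  override def serialize(buffer: MaxValue): Array[Byte] = {
    val bytes = ByteBuffer.allocate(9)
    bytes.putLong(buffer.value)
    bytes.put(if (buffer.isEmpty) 1.toByte else 0.toByte)
    bytes.array()
  }

  override def deserialize(storageFormat: Array[Byte]): MaxValue = {
    val bytes = ByteBuffer.wrap(storageFormat)
    new MaxValue(bytes.getLong, bytes.get() == 1.toByte)
  }

  override def children: Seq[Expression] = child :: Nil
  override def nullable: Boolean = true
  override def dataType: DataType = LongType
  override def inputTypes: Seq[AbstractDataType] = Seq(LongType)

  override def withNewMutableAggBufferOffset(newOffset: Int): TypedMax =
    copy(mutableAggBufferOffset = newOffset)
  override def withNewInputAggBufferOffset(newOffset: Int): TypedMax =
    copy(inputAggBufferOffset = newOffset)
}
```

Note that the buffer object never has to be re-created or re-serialized between `update` calls; serialization only happens when the framework persists or shuffles the buffer, as shown in the interface added by the diff below.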
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala  141
1 file changed, 141 insertions(+), 0 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 7a39e568fa..ecbaa2f466 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -389,3 +389,144 @@ abstract class DeclarativeAggregate
def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
}
}
+
+/**
+ * Aggregation function which allows an **arbitrary** user-defined Java object to be used as the
+ * internal aggregation buffer object.
+ *
+ * {{{
+ * aggregation buffer for normal aggregation function `avg`
+ * |
+ * v
+ * +--------------+---------------+-----------------------------------+
+ * | sum1 (Long) | count1 (Long) | generic user-defined java objects |
+ * +--------------+---------------+-----------------------------------+
+ * ^
+ * |
+ * Aggregation buffer object for `TypedImperativeAggregate` aggregation function
+ * }}}
+ *
+ * Workflow (Partial mode aggregate at Mapper side, and Final mode aggregate at Reducer side):
+ *
+ * Stage 1: Partial aggregate at Mapper side:
+ *
+ * 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
+ * buffer object.
+ * 2. Upon each input row, the framework calls
+ * `update(buffer: T, input: InternalRow): Unit` to update the aggregation buffer object T.
+ * 3. After processing all rows of the current group (group by key), the framework serializes
+ * the aggregation buffer object T to the storage format (Array[Byte]) and persists the
+ * Array[Byte] to disk if needed.
+ * 4. The framework moves on to the next group, until all groups have been processed.
+ *
+ * Shuffling exchange data to Reducer tasks...
+ *
+ * Stage 2: Final mode aggregate at Reducer side:
+ *
+ * 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
+ * buffer object (type T) for merging.
+ * 2. For each aggregation output of Stage 1, the framework de-serializes the storage
+ * format (Array[Byte]) and produces one input aggregation object (type T).
+ * 3. For each input aggregation object, the framework calls `merge(buffer: T, input: T): Unit`
+ * to merge the input aggregation object into the aggregation buffer object.
+ * 4. After processing all input aggregation objects of the current group (group by key), the
+ * framework calls method `eval(buffer: T)` to generate the final output for this group.
+ * 5. The framework moves on to the next group, until all groups have been processed.
+ *
+ * NOTE: SQL with TypedImperativeAggregate functions is planned with sort-based aggregation,
+ * instead of hash-based aggregation, as TypedImperativeAggregate uses BinaryType as the
+ * aggregation buffer's storage format, which is not supported by hash-based aggregation.
+ * Hash-based aggregation only supports aggregation buffers of mutable types (like LongType,
+ * IntType that have fixed length and can be mutated in place in an UnsafeRow).
+ */
+abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
+
+ /**
+ * Creates an empty aggregation buffer object. This is called before processing each key group
+ * (group by key).
+ *
+ * @return an aggregation buffer object
+ */
+ def createAggregationBuffer(): T
+
+ /**
+ * In-place updates the aggregation buffer object with an input row. buffer = buffer + input.
+ * This is typically called when doing Partial or Complete mode aggregation.
+ *
+ * @param buffer The aggregation buffer object.
+ * @param input an input row
+ */
+ def update(buffer: T, input: InternalRow): Unit
+
+ /**
+ * Merges an input aggregation object into aggregation buffer object. buffer = buffer + input.
+ * This is typically called when doing PartialMerge or Final mode aggregation.
+ *
+ * @param buffer the aggregation buffer object used to store the aggregation result.
+ * @param input an input aggregation object. The input aggregation object can be produced by
+ * de-serializing the partial aggregate's output from the Mapper side.
+ */
+ def merge(buffer: T, input: T): Unit
+
+ /**
+ * Generates the final aggregation result value for the current key group with the aggregation
+ * buffer object.
+ *
+ * @param buffer aggregation buffer object.
+ * @return The aggregation result of the current key group
+ */
+ def eval(buffer: T): Any
+
+ /** Serializes the aggregation buffer object T to Array[Byte] */
+ def serialize(buffer: T): Array[Byte]
+
+ /** De-serializes the serialized format Array[Byte], and produces the aggregation buffer object T */
+ def deserialize(storageFormat: Array[Byte]): T
+
+ final override def initialize(buffer: MutableRow): Unit = {
+ val bufferObject = createAggregationBuffer()
+ buffer.update(mutableAggBufferOffset, bufferObject)
+ }
+
+ final override def update(buffer: MutableRow, input: InternalRow): Unit = {
+ val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+ update(bufferObject, input)
+ }
+
+ final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
+ val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+ // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate
+ val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
+ merge(bufferObject, inputObject)
+ }
+
+ final override def eval(buffer: InternalRow): Any = {
+ val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+ eval(bufferObject)
+ }
+
+ private[this] val anyObjectType = ObjectType(classOf[AnyRef])
+ private def getField[U](input: InternalRow, fieldIndex: Int): U = {
+ input.get(fieldIndex, anyObjectType).asInstanceOf[U]
+ }
+
+ final override lazy val aggBufferAttributes: Seq[AttributeReference] = {
+ // Underlying storage type for the aggregation buffer object
+ Seq(AttributeReference("buf", BinaryType)())
+ }
+
+ final override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
+
+ final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+ /**
+ * In-place replaces the aggregation buffer object stored at the buffer's index
+ * `mutableAggBufferOffset` with the SparkSQL internally supported underlying storage format
+ * (BinaryType).
+ */
+ final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
+ val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+ buffer(mutableAggBufferOffset) = serialize(bufferObject)
+ }
+}
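To make the two-stage workflow described in the class doc above concrete, the following sketch spells out the call order the framework performs against this interface. The `partialAggregate` and `finalAggregate` helpers are hypothetical functions written for this illustration only; they are not Spark APIs.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate

object TypedAggregateWorkflowSketch {

  // Stage 1 (Mapper side): build a buffer per group and ship its serialized form.
  def partialAggregate[T](
      agg: TypedImperativeAggregate[T],
      rows: Iterator[InternalRow]): Array[Byte] = {
    val buffer = agg.createAggregationBuffer()    // 1. empty buffer for this group
    rows.foreach(row => agg.update(buffer, row))  // 2. fold each input row into the buffer
    agg.serialize(buffer)                         // 3. storage format (Array[Byte]) sent to the reducer
  }

  // Stage 2 (Reducer side): merge the shipped partial results and emit the final value.
  def finalAggregate[T](
      agg: TypedImperativeAggregate[T],
      partials: Iterator[Array[Byte]]): Any = {
    val buffer = agg.createAggregationBuffer()    // 1. empty buffer for merging
    partials.foreach { bytes =>
      agg.merge(buffer, agg.deserialize(bytes))   // 2-3. de-serialize and merge each partial result
    }
    agg.eval(buffer)                              // 4. final output for this key group
  }
}
```

Because the buffer's underlying storage attribute is BinaryType (see `aggBufferAttributes` above), queries using such functions are planned with sort-based rather than hash-based aggregation, as the NOTE in the class doc explains.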