author     Sean Zhong <seanzhong@databricks.com>    2016-08-25 16:36:16 -0700
committer  Yin Huai <yhuai@databricks.com>          2016-08-25 16:36:16 -0700
commit     d96d1515638da20b594f7bfe3cfdb50088f25a04 (patch)
tree       69e7803b4f49d0ed03073795843eb95d8f63529f
parent     9b5a1d1d53bc4412de3cbc86dc819b0c213229a8 (diff)
[SPARK-17187][SQL] Supports using arbitrary Java object as internal aggregation buffer object
## What changes were proposed in this pull request?

This PR introduces an abstract class `TypedImperativeAggregate` so that an aggregation function of TypedImperativeAggregate can use an **arbitrary** user-defined Java object as its intermediate aggregation buffer object.

**This has advantages like:**
1. It can now support a larger category of aggregation functions. For example, it makes it much easier to implement the aggregation function `percentile_approx`, which has a complex aggregation buffer definition.
2. It avoids serializing/de-serializing the domain-specific aggregation object to Spark SQL's internal storage format on every call of `update` or `merge`.
3. It makes it easier to integrate with existing monoid libraries like Algebird, supporting more aggregation functions with high performance.

Please see `org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax` for an example of how to define a `TypedImperativeAggregate` aggregation function. Please see the Java doc of `TypedImperativeAggregate` and JIRA ticket SPARK-17187 for more information.

## How was this patch tested?

Unit tests.

Author: Sean Zhong <seanzhong@databricks.com>
Author: Yin Huai <yhuai@databricks.com>

Closes #14753 from clockfly/object_aggregation_buffer_try_2.
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala   141
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala          15
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala                    300
3 files changed, 456 insertions(+), 0 deletions(-)
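For orientation before the diff, here is a minimal, hedged usage sketch. It is not part of the patch; it assumes an implementation like the `TypedMax` aggregate added in the test suite below (made visible to the calling code) and an active SparkSession with its implicits imported.

```scala
// Usage sketch (assumption: a TypedMax implementation like the one in the test suite below
// is in scope, and `import spark.implicits._` is active for $"..." and toDF).
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.count

// Wrap the imperative aggregate into a Column, as the test suite's typedMax helper does.
def typedMax(column: Column): Column =
  new Column(TypedMax(column.expr).toAggregateExpression())

val df = Seq((1, 10), (1, 20), (2, 5)).toDF("key", "value")
df.groupBy($"key").agg(typedMax($"value"), count($"value")).show()
```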
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 7a39e568fa..ecbaa2f466 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -389,3 +389,144 @@ abstract class DeclarativeAggregate
def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
}
}
+
+/**
+ * Aggregation function which allows an **arbitrary** user-defined Java object to be used as the
+ * internal aggregation buffer object.
+ *
+ * {{{
+ * aggregation buffer for normal aggregation function `avg`
+ * |
+ * v
+ * +--------------+---------------+-----------------------------------+
+ * | sum1 (Long) | count1 (Long) | generic user-defined java objects |
+ * +--------------+---------------+-----------------------------------+
+ * ^
+ * |
+ * Aggregation buffer object for `TypedImperativeAggregate` aggregation function
+ * }}}
+ *
+ * Workflow (Partial mode aggregate at the Mapper side, Final mode aggregate at the Reducer side):
+ *
+ * Stage 1: Partial aggregate at the Mapper side:
+ *
+ * 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
+ *    buffer object.
+ * 2. Upon each input row, the framework calls
+ *    `update(buffer: T, input: InternalRow): Unit` to update the aggregation buffer object T.
+ * 3. After processing all rows of the current group (group by key), the framework serializes the
+ *    aggregation buffer object T to the storage format (Array[Byte]) and persists the Array[Byte]
+ *    to disk if needed.
+ * 4. The framework moves on to the next group, until all groups have been processed.
+ *
+ * Shuffling then exchanges the data to the Reducer tasks...
+ *
+ * Stage 2: Final mode aggregate at the Reducer side:
+ *
+ * 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
+ *    buffer object (type T) for merging.
+ * 2. For each aggregation output of Stage 1, the framework de-serializes the storage
+ *    format (Array[Byte]) and produces one input aggregation object (type T).
+ * 3. For each input aggregation object, the framework calls `merge(buffer: T, input: T): Unit`
+ *    to merge the input aggregation object into the aggregation buffer object.
+ * 4. After processing all input aggregation objects of the current group (group by key), the
+ *    framework calls `eval(buffer: T)` to generate the final output for this group.
+ * 5. The framework moves on to the next group, until all groups have been processed.
+ *
+ * NOTE: SQL with TypedImperativeAggregate functions is planned with sort-based aggregation
+ * instead of hash-based aggregation, as TypedImperativeAggregate uses BinaryType as the
+ * aggregation buffer's storage format, which is not supported by hash-based aggregation.
+ * Hash-based aggregation only supports aggregation buffers of mutable types (like LongType and
+ * IntegerType, which have a fixed length and can be mutated in place in UnsafeRow).
+ */
+abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
+
+ /**
+ * Creates an empty aggregation buffer object. This is called before processing each key group
+ * (group by key).
+ *
+ * @return an aggregation buffer object
+ */
+ def createAggregationBuffer(): T
+
+ /**
+ * In-place updates the aggregation buffer object with an input row. buffer = buffer + input.
+ * This is typically called when doing Partial or Complete mode aggregation.
+ *
+ * @param buffer The aggregation buffer object.
+ * @param input an input row
+ */
+ def update(buffer: T, input: InternalRow): Unit
+
+ /**
+ * Merges an input aggregation object into aggregation buffer object. buffer = buffer + input.
+ * This is typically called when doing PartialMerge or Final mode aggregation.
+ *
+ * @param buffer the aggregation buffer object used to store the aggregation result.
+ * @param input an input aggregation object. The input aggregation object can be produced by
+ * de-serializing the partial aggregate's output from the Mapper side.
+ */
+ def merge(buffer: T, input: T): Unit
+
+ /**
+ * Generates the final aggregation result value for the current key group with the aggregation
+ * buffer object.
+ *
+ * @param buffer aggregation buffer object.
+ * @return The aggregation result of the current key group
+ */
+ def eval(buffer: T): Any
+
+ /** Serializes the aggregation buffer object T to Array[Byte] */
+ def serialize(buffer: T): Array[Byte]
+
+ /** De-serializes the serialized format Array[Byte], and produces the aggregation buffer object T */
+ def deserialize(storageFormat: Array[Byte]): T
+
+ final override def initialize(buffer: MutableRow): Unit = {
+ val bufferObject = createAggregationBuffer()
+ buffer.update(mutableAggBufferOffset, bufferObject)
+ }
+
+ final override def update(buffer: MutableRow, input: InternalRow): Unit = {
+ val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+ update(bufferObject, input)
+ }
+
+ final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
+ val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+ // The inputBuffer stores the serialized aggregation buffer object produced by the partial aggregate
+ val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
+ merge(bufferObject, inputObject)
+ }
+
+ final override def eval(buffer: InternalRow): Any = {
+ val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+ eval(bufferObject)
+ }
+
+ private[this] val anyObjectType = ObjectType(classOf[AnyRef])
+ private def getField[U](input: InternalRow, fieldIndex: Int): U = {
+ input.get(fieldIndex, anyObjectType).asInstanceOf[U]
+ }
+
+ final override lazy val aggBufferAttributes: Seq[AttributeReference] = {
+ // Underlying storage type for the aggregation buffer object
+ Seq(AttributeReference("buf", BinaryType)())
+ }
+
+ final override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
+
+ final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+ /**
+ * In-place replaces the aggregation buffer object stored at the buffer's index
+ * `mutableAggBufferOffset` with the Spark SQL internally supported underlying storage format
+ * (BinaryType).
+ */
+ final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
+ val bufferObject = getField[T](buffer, mutableAggBufferOffset)
+ buffer(mutableAggBufferOffset) = serialize(bufferObject)
+ }
+}
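The Scaladoc added above describes the two-stage workflow in prose; the hedged sketch below restates it against the new abstract methods. The helper `runOneGroup` and its row/byte collections are hypothetical stand-ins for what the framework actually drives, not Spark internals.

```scala
// Hedged sketch of the two-stage workflow documented in the Scaladoc above, written against
// the TypedImperativeAggregate[T] methods. mapperRows and reducerInputs are placeholders for
// one group's input rows and the serialized partial results arriving from other mappers.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate

def runOneGroup[T](
    agg: TypedImperativeAggregate[T],
    mapperRows: Seq[InternalRow],
    reducerInputs: Seq[Array[Byte]]): Any = {
  // Stage 1: partial aggregation at the Mapper side.
  val partialBuffer = agg.createAggregationBuffer()          // empty buffer object of type T
  mapperRows.foreach(row => agg.update(partialBuffer, row))
  val shuffled: Array[Byte] = agg.serialize(partialBuffer)   // shuffled/persisted as BinaryType

  // Stage 2: final aggregation at the Reducer side.
  val finalBuffer = agg.createAggregationBuffer()
  (reducerInputs :+ shuffled).foreach { bytes =>
    agg.merge(finalBuffer, agg.deserialize(bytes))           // merge each partial result
  }
  agg.eval(finalBuffer)                                      // final output for this group
}
```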
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 34de76dd4a..dfed084fe6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -234,7 +234,22 @@ abstract class AggregationIterator(
val resultProjection = UnsafeProjection.create(
groupingAttributes ++ bufferAttributes,
groupingAttributes ++ bufferAttributes)
+
+ // TypedImperativeAggregate stores a generic object in the aggregation buffer and requires
+ // serialization before shuffling. See [[TypedImperativeAggregate]] for more info.
+ val typedImperativeAggregates: Array[TypedImperativeAggregate[_]] = {
+ aggregateFunctions.collect {
+ case (ag: TypedImperativeAggregate[_]) => ag
+ }
+ }
+
(currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
+ // Serializes the generic object stored in the aggregation buffer
+ var i = 0
+ while (i < typedImperativeAggregates.length) {
+ typedImperativeAggregates(i).serializeAggregateBufferInPlace(currentBuffer)
+ i += 1
+ }
resultProjection(joinedRow(currentGroupingKey, currentBuffer))
}
} else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
new file mode 100644
index 0000000000..b5eb16b6f6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericMutableRow, SpecificMutableRow}
+import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
+import org.apache.spark.sql.execution.aggregate.SortAggregateExec
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, IntegerType, LongType}
+
+class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
+
+ import testImplicits._
+
+ private val random = new java.util.Random()
+
+ private val data = (0 until 1000).map { _ =>
+ (random.nextInt(10), random.nextInt(100))
+ }
+
+ test("aggregate with object aggregate buffer") {
+ val agg = new TypedMax(BoundReference(0, IntegerType, nullable = false))
+
+ val group1 = (0 until data.length / 2)
+ val group1Buffer = agg.createAggregationBuffer()
+ group1.foreach { index =>
+ val input = InternalRow(data(index)._1, data(index)._2)
+ agg.update(group1Buffer, input)
+ }
+
+ val group2 = (data.length / 2 until data.length)
+ val group2Buffer = agg.createAggregationBuffer()
+ group2.foreach { index =>
+ val input = InternalRow(data(index)._1, data(index)._2)
+ agg.update(group2Buffer, input)
+ }
+
+ val mergeBuffer = agg.createAggregationBuffer()
+ agg.merge(mergeBuffer, group1Buffer)
+ agg.merge(mergeBuffer, group2Buffer)
+
+ assert(mergeBuffer.value == data.map(_._1).max)
+ assert(agg.eval(mergeBuffer) == data.map(_._1).max)
+
+ // Tests the low-level eval(row: InternalRow) API.
+ val row = new GenericMutableRow(Array(mergeBuffer): Array[Any])
+
+ // Evaluates directly on a row consisting of the aggregation buffer object.
+ assert(agg.eval(row) == data.map(_._1).max)
+ }
+
+ test("supports SpecificMutableRow as mutable row") {
+ val aggregationBufferSchema = Seq(IntegerType, LongType, BinaryType, IntegerType)
+ val aggBufferOffset = 2
+ val buffer = new SpecificMutableRow(aggregationBufferSchema)
+ val agg = new TypedMax(BoundReference(ordinal = 1, dataType = IntegerType, nullable = false))
+ .withNewMutableAggBufferOffset(aggBufferOffset)
+
+ agg.initialize(buffer)
+ data.foreach { kv =>
+ val input = InternalRow(kv._1, kv._2)
+ agg.update(buffer, input)
+ }
+ assert(agg.eval(buffer) == data.map(_._2).max)
+ }
+
+ test("dataframe aggregate with object aggregate buffer, should not use HashAggregate") {
+ val df = data.toDF("a", "b")
+ val max = new TypedMax($"a".expr)
+
+ // Always uses SortAggregateExec
+ val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan
+ assert(sparkPlan.isInstanceOf[SortAggregateExec])
+ }
+
+ test("dataframe aggregate with object aggregate buffer, no group by") {
+ val df = data.toDF("key", "value").coalesce(2)
+ val query = df.select(typedMax($"key"), count($"key"), typedMax($"value"), count($"value"))
+ val maxKey = data.map(_._1).max
+ val countKey = data.size
+ val maxValue = data.map(_._2).max
+ val countValue = data.size
+ val expected = Seq(Row(maxKey, countKey, maxValue, countValue))
+ checkAnswer(query, expected)
+ }
+
+ test("dataframe aggregate with object aggregate buffer, non-nullable aggregator") {
+ val df = data.toDF("key", "value").coalesce(2)
+
+ // Test non-nullable typedMax
+ val query = df.select(typedMax(lit(null)), count($"key"), typedMax(lit(null)),
+ count($"value"))
+
+ // typedMax is not nullable
+ val maxNull = Int.MinValue
+ val countKey = data.size
+ val countValue = data.size
+ val expected = Seq(Row(maxNull, countKey, maxNull, countValue))
+ checkAnswer(query, expected)
+ }
+
+ test("dataframe aggregate with object aggregate buffer, nullable aggregator") {
+ val df = data.toDF("key", "value").coalesce(2)
+
+ // Test nullable nullableTypedMax
+ val query = df.select(nullableTypedMax(lit(null)), count($"key"), nullableTypedMax(lit(null)),
+ count($"value"))
+
+ // nullableTypedMax is nullable
+ val maxNull = null
+ val countKey = data.size
+ val countValue = data.size
+ val expected = Seq(Row(maxNull, countKey, maxNull, countValue))
+ checkAnswer(query, expected)
+ }
+
+ test("dataframe aggregation with object aggregate buffer, input row contains null") {
+
+ val nullableData = (0 until 1000).map {id =>
+ val nullableKey: Integer = if (random.nextBoolean()) null else random.nextInt(100)
+ val nullableValue: Integer = if (random.nextBoolean()) null else random.nextInt(100)
+ (nullableKey, nullableValue)
+ }
+
+ val df = nullableData.toDF("key", "value").coalesce(2)
+ val query = df.select(typedMax($"key"), count($"key"), typedMax($"value"),
+ count($"value"))
+ val maxKey = nullableData.map(_._1).filter(_ != null).max
+ val countKey = nullableData.map(_._1).filter(_ != null).size
+ val maxValue = nullableData.map(_._2).filter(_ != null).max
+ val countValue = nullableData.map(_._2).filter(_ != null).size
+ val expected = Seq(Row(maxKey, countKey, maxValue, countValue))
+ checkAnswer(query, expected)
+ }
+
+ test("dataframe aggregate with object aggregate buffer, with group by") {
+ val df = data.toDF("value", "key").coalesce(2)
+ val query = df.groupBy($"key").agg(typedMax($"value"), count($"value"), typedMax($"value"))
+ val expected = data.groupBy(_._2).toSeq.map { group =>
+ val (key, values) = group
+ val valueMax = values.map(_._1).max
+ val countValue = values.size
+ Row(key, valueMax, countValue, valueMax)
+ }
+ checkAnswer(query, expected)
+ }
+
+ test("dataframe aggregate with object aggregate buffer, empty inputs, no group by") {
+ val empty = Seq.empty[(Int, Int)].toDF("a", "b")
+ checkAnswer(
+ empty.select(typedMax($"a"), count($"a"), typedMax($"b"), count($"b")),
+ Seq(Row(Int.MinValue, 0, Int.MinValue, 0)))
+ }
+
+ test("dataframe aggregate with object aggregate buffer, empty inputs, with group by") {
+ val empty = Seq.empty[(Int, Int)].toDF("a", "b")
+ checkAnswer(
+ empty.groupBy($"b").agg(typedMax($"a"), count($"a"), typedMax($"a")),
+ Seq.empty[Row])
+ }
+
+ test("TypedImperativeAggregate should not break Window function") {
+ val df = data.toDF("key", "value")
+ // OVER (PARTITION BY key ORDER BY value ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
+ val w = Window.orderBy("value").partitionBy("key").rowsBetween(Long.MinValue, 0)
+
+ val query = df.select(sum($"key").over(w), typedMax($"key").over(w), sum($"value").over(w),
+ typedMax($"value").over(w))
+
+ val expected = data.groupBy(_._1).toSeq.flatMap { group =>
+ val (key, values) = group
+ val sortedValues = values.map(_._2).sorted
+
+ var outputRows = Seq.empty[Row]
+ var i = 0
+ while (i < sortedValues.size) {
+ val unboundedPrecedingAndCurrent = sortedValues.slice(0, i + 1)
+ val sumKey = key * unboundedPrecedingAndCurrent.size
+ val maxKey = key
+ val sumValue = unboundedPrecedingAndCurrent.sum
+ val maxValue = unboundedPrecedingAndCurrent.max
+
+ outputRows :+= Row(sumKey, maxKey, sumValue, maxValue)
+ i += 1
+ }
+
+ outputRows
+ }
+ checkAnswer(query, expected)
+ }
+
+ private def typedMax(column: Column): Column = {
+ val max = TypedMax(column.expr, nullable = false)
+ Column(max.toAggregateExpression())
+ }
+
+ private def nullableTypedMax(column: Column): Column = {
+ val max = TypedMax(column.expr, nullable = true)
+ Column(max.toAggregateExpression())
+ }
+}
+
+object TypedImperativeAggregateSuite {
+
+ /**
+ * Calculates the max value with an object aggregation buffer. This stores an instance of the
+ * MaxValue class in the aggregation buffer.
+ */
+ private case class TypedMax(
+ child: Expression,
+ nullable: Boolean = false,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[MaxValue] {
+
+
+ override def createAggregationBuffer(): MaxValue = {
+ // Returns Int.MinValue if all inputs are null
+ new MaxValue(Int.MinValue)
+ }
+
+ override def update(buffer: MaxValue, input: InternalRow): Unit = {
+ child.eval(input) match {
+ case inputValue: Int =>
+ if (inputValue > buffer.value) {
+ buffer.value = inputValue
+ buffer.isValueSet = true
+ }
+ case null => // skip
+ }
+ }
+
+ override def merge(bufferMax: MaxValue, inputMax: MaxValue): Unit = {
+ if (inputMax.value > bufferMax.value) {
+ bufferMax.value = inputMax.value
+ bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet
+ }
+ }
+
+ override def eval(bufferMax: MaxValue): Any = {
+ if (nullable && bufferMax.isValueSet == false) {
+ null
+ } else {
+ bufferMax.value
+ }
+ }
+
+ override def deterministic: Boolean = true
+
+ override def children: Seq[Expression] = Seq(child)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
+
+ override def dataType: DataType = IntegerType
+
+ override def withNewMutableAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
+ copy(mutableAggBufferOffset = newOffset)
+
+ override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
+ copy(inputAggBufferOffset = newOffset)
+
+ override def serialize(buffer: MaxValue): Array[Byte] = {
+ val out = new ByteArrayOutputStream()
+ val stream = new DataOutputStream(out)
+ stream.writeBoolean(buffer.isValueSet)
+ stream.writeInt(buffer.value)
+ out.toByteArray
+ }
+
+ override def deserialize(storageFormat: Array[Byte]): MaxValue = {
+ val in = new ByteArrayInputStream(storageFormat)
+ val stream = new DataInputStream(in)
+ val isValueSet = stream.readBoolean()
+ val value = stream.readInt()
+ new MaxValue(value, isValueSet)
+ }
+ }
+
+ private class MaxValue(var value: Int, var isValueSet: Boolean = false)
+}
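Finally, a small hedged round-trip sketch for the buffer codec defined by `TypedMax` above. It assumes it runs inside `TypedImperativeAggregateSuite`, where the private `TypedMax` and `MaxValue` are visible.

```scala
// Round-trip sketch: serialize followed by deserialize should preserve both fields of MaxValue.
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.IntegerType

val agg = TypedMax(BoundReference(0, IntegerType, nullable = false))
val buffer = new MaxValue(42, isValueSet = true)
val restored = agg.deserialize(agg.serialize(buffer))
// Both the value and the isValueSet flag survive the byte-array round trip.
assert(restored.value == 42 && restored.isValueSet)
```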