From d96d1515638da20b594f7bfe3cfdb50088f25a04 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Thu, 25 Aug 2016 16:36:16 -0700 Subject: [SPARK-17187][SQL] Supports using arbitrary Java object as internal aggregation buffer object ## What changes were proposed in this pull request? This PR introduces an abstract class `TypedImperativeAggregate` so that an aggregation function of TypedImperativeAggregate can use **arbitrary** user-defined Java object as intermediate aggregation buffer object. **This has advantages like:** 1. It now can support larger category of aggregation functions. For example, it will be much easier to implement aggregation function `percentile_approx`, which has a complex aggregation buffer definition. 2. It can be used to avoid doing serialization/de-serialization for every call of `update` or `merge` when converting domain specific aggregation object to internal Spark-Sql storage format. 3. It is easier to integrate with other existing monoid libraries like algebird, and supports more aggregation functions with high performance. Please see `org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMaxAggregate` to find an example of how to defined a `TypedImperativeAggregate` aggregation function. Please see Java doc of `TypedImperativeAggregate` and Jira ticket SPARK-17187 for more information. ## How was this patch tested? Unit tests. Author: Sean Zhong Author: Yin Huai Closes #14753 from clockfly/object_aggregation_buffer_try_2. --- .../execution/aggregate/AggregationIterator.scala | 15 ++ .../spark/sql/TypedImperativeAggregateSuite.scala | 300 +++++++++++++++++++++ 2 files changed, 315 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala (limited to 'sql/core/src') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 34de76dd4a..dfed084fe6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -234,7 +234,22 @@ abstract class AggregationIterator( val resultProjection = UnsafeProjection.create( groupingAttributes ++ bufferAttributes, groupingAttributes ++ bufferAttributes) + + // TypedImperativeAggregate stores generic object in aggregation buffer, and requires + // calling serialization before shuffling. See [[TypedImperativeAggregate]] for more info. + val typedImperativeAggregates: Array[TypedImperativeAggregate[_]] = { + aggregateFunctions.collect { + case (ag: TypedImperativeAggregate[_]) => ag + } + } + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + // Serializes the generic object stored in aggregation buffer + var i = 0 + while (i < typedImperativeAggregates.length) { + typedImperativeAggregates(i).serializeAggregateBufferInPlace(currentBuffer) + i += 1 + } resultProjection(joinedRow(currentGroupingKey, currentBuffer)) } } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala new file mode 100644 index 0000000000..b5eb16b6f6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate +import org.apache.spark.sql.execution.aggregate.SortAggregateExec +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, IntegerType, LongType} + +class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + private val random = new java.util.Random() + + private val data = (0 until 1000).map { _ => + (random.nextInt(10), random.nextInt(100)) + } + + test("aggregate with object aggregate buffer") { + val agg = new TypedMax(BoundReference(0, IntegerType, nullable = false)) + + val group1 = (0 until data.length / 2) + val group1Buffer = agg.createAggregationBuffer() + group1.foreach { index => + val input = InternalRow(data(index)._1, data(index)._2) + agg.update(group1Buffer, input) + } + + val group2 = (data.length / 2 until data.length) + val group2Buffer = agg.createAggregationBuffer() + group2.foreach { index => + val input = InternalRow(data(index)._1, data(index)._2) + agg.update(group2Buffer, input) + } + + val mergeBuffer = agg.createAggregationBuffer() + agg.merge(mergeBuffer, group1Buffer) + agg.merge(mergeBuffer, group2Buffer) + + assert(mergeBuffer.value == data.map(_._1).max) + assert(agg.eval(mergeBuffer) == data.map(_._1).max) + + // Tests low level eval(row: InternalRow) API. + val row = new GenericMutableRow(Array(mergeBuffer): Array[Any]) + + // Evaluates directly on row consist of aggregation buffer object. + assert(agg.eval(row) == data.map(_._1).max) + } + + test("supports SpecificMutableRow as mutable row") { + val aggregationBufferSchema = Seq(IntegerType, LongType, BinaryType, IntegerType) + val aggBufferOffset = 2 + val buffer = new SpecificMutableRow(aggregationBufferSchema) + val agg = new TypedMax(BoundReference(ordinal = 1, dataType = IntegerType, nullable = false)) + .withNewMutableAggBufferOffset(aggBufferOffset) + + agg.initialize(buffer) + data.foreach { kv => + val input = InternalRow(kv._1, kv._2) + agg.update(buffer, input) + } + assert(agg.eval(buffer) == data.map(_._2).max) + } + + test("dataframe aggregate with object aggregate buffer, should not use HashAggregate") { + val df = data.toDF("a", "b") + val max = new TypedMax($"a".expr) + + // Always uses SortAggregateExec + val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan + assert(sparkPlan.isInstanceOf[SortAggregateExec]) + } + + test("dataframe aggregate with object aggregate buffer, no group by") { + val df = data.toDF("key", "value").coalesce(2) + val query = df.select(typedMax($"key"), count($"key"), typedMax($"value"), count($"value")) + val maxKey = data.map(_._1).max + val countKey = data.size + val maxValue = data.map(_._2).max + val countValue = data.size + val expected = Seq(Row(maxKey, countKey, maxValue, countValue)) + checkAnswer(query, expected) + } + + test("dataframe aggregate with object aggregate buffer, non-nullable aggregator") { + val df = data.toDF("key", "value").coalesce(2) + + // Test non-nullable typedMax + val query = df.select(typedMax(lit(null)), count($"key"), typedMax(lit(null)), + count($"value")) + + // typedMax is not nullable + val maxNull = Int.MinValue + val countKey = data.size + val countValue = data.size + val expected = Seq(Row(maxNull, countKey, maxNull, countValue)) + checkAnswer(query, expected) + } + + test("dataframe aggregate with object aggregate buffer, nullable aggregator") { + val df = data.toDF("key", "value").coalesce(2) + + // Test nullable nullableTypedMax + val query = df.select(nullableTypedMax(lit(null)), count($"key"), nullableTypedMax(lit(null)), + count($"value")) + + // nullableTypedMax is nullable + val maxNull = null + val countKey = data.size + val countValue = data.size + val expected = Seq(Row(maxNull, countKey, maxNull, countValue)) + checkAnswer(query, expected) + } + + test("dataframe aggregation with object aggregate buffer, input row contains null") { + + val nullableData = (0 until 1000).map {id => + val nullableKey: Integer = if (random.nextBoolean()) null else random.nextInt(100) + val nullableValue: Integer = if (random.nextBoolean()) null else random.nextInt(100) + (nullableKey, nullableValue) + } + + val df = nullableData.toDF("key", "value").coalesce(2) + val query = df.select(typedMax($"key"), count($"key"), typedMax($"value"), + count($"value")) + val maxKey = nullableData.map(_._1).filter(_ != null).max + val countKey = nullableData.map(_._1).filter(_ != null).size + val maxValue = nullableData.map(_._2).filter(_ != null).max + val countValue = nullableData.map(_._2).filter(_ != null).size + val expected = Seq(Row(maxKey, countKey, maxValue, countValue)) + checkAnswer(query, expected) + } + + test("dataframe aggregate with object aggregate buffer, with group by") { + val df = data.toDF("value", "key").coalesce(2) + val query = df.groupBy($"key").agg(typedMax($"value"), count($"value"), typedMax($"value")) + val expected = data.groupBy(_._2).toSeq.map { group => + val (key, values) = group + val valueMax = values.map(_._1).max + val countValue = values.size + Row(key, valueMax, countValue, valueMax) + } + checkAnswer(query, expected) + } + + test("dataframe aggregate with object aggregate buffer, empty inputs, no group by") { + val empty = Seq.empty[(Int, Int)].toDF("a", "b") + checkAnswer( + empty.select(typedMax($"a"), count($"a"), typedMax($"b"), count($"b")), + Seq(Row(Int.MinValue, 0, Int.MinValue, 0))) + } + + test("dataframe aggregate with object aggregate buffer, empty inputs, with group by") { + val empty = Seq.empty[(Int, Int)].toDF("a", "b") + checkAnswer( + empty.groupBy($"b").agg(typedMax($"a"), count($"a"), typedMax($"a")), + Seq.empty[Row]) + } + + test("TypedImperativeAggregate should not break Window function") { + val df = data.toDF("key", "value") + // OVER (PARTITION BY a ORDER BY b ROW BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + val w = Window.orderBy("value").partitionBy("key").rowsBetween(Long.MinValue, 0) + + val query = df.select(sum($"key").over(w), typedMax($"key").over(w), sum($"value").over(w), + typedMax($"value").over(w)) + + val expected = data.groupBy(_._1).toSeq.flatMap { group => + val (key, values) = group + val sortedValues = values.map(_._2).sorted + + var outputRows = Seq.empty[Row] + var i = 0 + while (i < sortedValues.size) { + val unboundedPrecedingAndCurrent = sortedValues.slice(0, i + 1) + val sumKey = key * unboundedPrecedingAndCurrent.size + val maxKey = key + val sumValue = unboundedPrecedingAndCurrent.sum + val maxValue = unboundedPrecedingAndCurrent.max + + outputRows :+= Row(sumKey, maxKey, sumValue, maxValue) + i += 1 + } + + outputRows + } + checkAnswer(query, expected) + } + + private def typedMax(column: Column): Column = { + val max = TypedMax(column.expr, nullable = false) + Column(max.toAggregateExpression()) + } + + private def nullableTypedMax(column: Column): Column = { + val max = TypedMax(column.expr, nullable = true) + Column(max.toAggregateExpression()) + } +} + +object TypedImperativeAggregateSuite { + + /** + * Calculate the max value with object aggregation buffer. This stores class MaxValue + * in aggregation buffer. + */ + private case class TypedMax( + child: Expression, + nullable: Boolean = false, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[MaxValue] { + + + override def createAggregationBuffer(): MaxValue = { + // Returns Int.MinValue if all inputs are null + new MaxValue(Int.MinValue) + } + + override def update(buffer: MaxValue, input: InternalRow): Unit = { + child.eval(input) match { + case inputValue: Int => + if (inputValue > buffer.value) { + buffer.value = inputValue + buffer.isValueSet = true + } + case null => // skip + } + } + + override def merge(bufferMax: MaxValue, inputMax: MaxValue): Unit = { + if (inputMax.value > bufferMax.value) { + bufferMax.value = inputMax.value + bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet + } + } + + override def eval(bufferMax: MaxValue): Any = { + if (nullable && bufferMax.isValueSet == false) { + null + } else { + bufferMax.value + } + } + + override def deterministic: Boolean = true + + override def children: Seq[Expression] = Seq(child) + + override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType) + + override def dataType: DataType = IntegerType + + override def withNewMutableAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] = + copy(mutableAggBufferOffset = newOffset) + + override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] = + copy(inputAggBufferOffset = newOffset) + + override def serialize(buffer: MaxValue): Array[Byte] = { + val out = new ByteArrayOutputStream() + val stream = new DataOutputStream(out) + stream.writeBoolean(buffer.isValueSet) + stream.writeInt(buffer.value) + out.toByteArray + } + + override def deserialize(storageFormat: Array[Byte]): MaxValue = { + val in = new ByteArrayInputStream(storageFormat) + val stream = new DataInputStream(in) + val isValueSet = stream.readBoolean() + val value = stream.readInt() + new MaxValue(value, isValueSet) + } + } + + private class MaxValue(var value: Int, var isValueSet: Boolean = false) +} -- cgit v1.2.3