diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-12-26 22:10:20 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-12-26 22:10:20 +0800 |
commit | 8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799 (patch) | |
tree | d7ce89b11fb0dbb9becb7fbf2e3358afc34fd0b3 /sql/hive | |
parent | 7026ee23e0a684e13f9d7dfbb8f85e810106d022 (diff) | |
download | spark-8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799.tar.gz spark-8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799.tar.bz2 spark-8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799.zip |
[SPARK-18980][SQL] implement Aggregator with TypedImperativeAggregate
## What changes were proposed in this pull request?
Currently we implement `Aggregator` with `DeclarativeAggregate`, which will serialize/deserialize the buffer object every time we process an input.
This PR implements `Aggregator` with `TypedImperativeAggregate` and avoids serializing/deserializing the buffer object many times. The benchmark shows we get about a 2x speedup.
For simple buffer object that doesn't need serialization, we still go with `DeclarativeAggregate`, to avoid performance regression.
## How was this patch tested?
N/A
Author: Wenchen Fan <wenchen@databricks.com>
Closes #16383 from cloud-fan/aggregator.
Diffstat (limited to 'sql/hive')
-rw-r--r-- | sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 6 | ||||
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala | 6 |
2 files changed, 8 insertions, 4 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 26dc372d7c..fcefd69272 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -397,17 +397,19 @@ private[hive] case class HiveUDAFFunction( @transient private lazy val inputProjection = UnsafeProjection.create(children) - override def update(buffer: AggregationBuffer, input: InternalRow): Unit = { + override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = { partial1ModeEvaluator.iterate( buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes)) + buffer } - override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = { + override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = { // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion. 
partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input)) + buffer } override def eval(buffer: AggregationBuffer): Any = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala index d27287bad0..aaf1db65a6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala @@ -46,14 +46,16 @@ case class TestingTypedCount( override def createAggregationBuffer(): State = TestingTypedCount.State(0L) - override def update(buffer: State, input: InternalRow): Unit = { + override def update(buffer: State, input: InternalRow): State = { if (child.eval(input) != null) { buffer.count += 1 } + buffer } - override def merge(buffer: State, input: State): Unit = { + override def merge(buffer: State, input: State): State = { buffer.count += input.count + buffer } override def eval(buffer: State): Any = buffer.count |