aboutsummaryrefslogtreecommitdiff
path: root/sql/hive
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-12-26 22:10:20 +0800
committerWenchen Fan <wenchen@databricks.com>2016-12-26 22:10:20 +0800
commit8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799 (patch)
treed7ce89b11fb0dbb9becb7fbf2e3358afc34fd0b3 /sql/hive
parent7026ee23e0a684e13f9d7dfbb8f85e810106d022 (diff)
downloadspark-8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799.tar.gz
spark-8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799.tar.bz2
spark-8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799.zip
[SPARK-18980][SQL] implement Aggregator with TypedImperativeAggregate
## What changes were proposed in this pull request? Currently we implement `Aggregator` with `DeclarativeAggregate`, which serializes/deserializes the buffer object every time we process an input. This PR implements `Aggregator` with `TypedImperativeAggregate` and avoids serializing/deserializing the buffer object many times. The benchmark shows a roughly 2x speedup. For a simple buffer object that doesn't need serialization, we still go with `DeclarativeAggregate`, to avoid a performance regression. ## How was this patch tested? N/A Author: Wenchen Fan <wenchen@databricks.com> Closes #16383 from cloud-fan/aggregator.
Diffstat (limited to 'sql/hive')
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala6
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala6
2 files changed, 8 insertions, 4 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 26dc372d7c..fcefd69272 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -397,17 +397,19 @@ private[hive] case class HiveUDAFFunction(
@transient
private lazy val inputProjection = UnsafeProjection.create(children)
- override def update(buffer: AggregationBuffer, input: InternalRow): Unit = {
+ override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
partial1ModeEvaluator.iterate(
buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+ buffer
}
- override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = {
+ override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
// The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
// buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
// this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
// calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input))
+ buffer
}
override def eval(buffer: AggregationBuffer): Any = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
index d27287bad0..aaf1db65a6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
@@ -46,14 +46,16 @@ case class TestingTypedCount(
override def createAggregationBuffer(): State = TestingTypedCount.State(0L)
- override def update(buffer: State, input: InternalRow): Unit = {
+ override def update(buffer: State, input: InternalRow): State = {
if (child.eval(input) != null) {
buffer.count += 1
}
+ buffer
}
- override def merge(buffer: State, input: State): Unit = {
+ override def merge(buffer: State, input: State): State = {
buffer.count += input.count
+ buffer
}
override def eval(buffer: State): Any = buffer.count