| author | Cheng Lian <lian@databricks.com> | 2016-11-16 14:32:36 -0800 |
|---|---|---|
| committer | Yin Huai <yhuai@databricks.com> | 2016-11-16 14:32:36 -0800 |
| commit | 2ca8ae9aa1b29bf1f46d0b656d9885e438e67f53 (patch) | |
| tree | d254abf510f28e509b15fd97e7457e0b2ed66b27 /sql/hive/src/test | |
| parent | a36a76ac43c36a3b897a748bd9f138b629dbc684 (diff) | |
[SPARK-18186] Migrate HiveUDAFFunction to TypedImperativeAggregate for partial aggregation support
## What changes were proposed in this pull request?
When evaluated in Spark SQL, Hive UDAFs don't support partial aggregation. This PR migrates `HiveUDAFFunction` to `TypedImperativeAggregate`, which already provides partial aggregation support for aggregate functions that may use arbitrary Java objects as aggregation states.
The following snippet shows the effect of this PR:
```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax
sql(s"CREATE FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
spark.range(100).createOrReplaceTempView("t")
// A query using both Spark SQL native `max` and Hive `max`
sql(s"SELECT max(id), hive_max(id) FROM t").explain()
```
Before this PR:
```
== Physical Plan ==
SortAggregate(key=[], functions=[max(id#1L), default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax7475f57e), id#1L, false, 0, 0)])
+- Exchange SinglePartition
+- *Range (0, 100, step=1, splits=Some(1))
```
After this PR:
```
== Physical Plan ==
SortAggregate(key=[], functions=[max(id#1L), default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax5e18a6a7), id#1L, false, 0, 0)])
+- Exchange SinglePartition
+- SortAggregate(key=[], functions=[partial_max(id#1L), partial_default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax5e18a6a7), id#1L, false, 0, 0)])
+- *Range (0, 100, step=1, splits=Some(1))
```
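For context, `TypedImperativeAggregate[T]` lets an aggregate function keep its state as an arbitrary Java object of type `T`, as long as the state can be serialized to bytes for shuffling. The outline below is a paraphrased sketch of that contract, not a verbatim copy: method names match Spark's API, but the signatures are simplified for illustration (e.g., depending on the Spark version, `update`/`merge` return the buffer or mutate it in place), so consult Spark's `interfaces.scala` for the authoritative definitions.

```scala
import org.apache.spark.sql.catalyst.InternalRow

// Paraphrased sketch of the TypedImperativeAggregate[T] contract.
// For HiveUDAFFunction, T is (a wrapper around) a Hive AggregationBuffer.
abstract class TypedImperativeAggregateSketch[T] {
  // Create a fresh, empty aggregation state.
  def createAggregationBuffer(): T

  // Fold one input row into the state (runs during partial aggregation).
  def update(buffer: T, input: InternalRow): T

  // Merge a state received from another partition into this one.
  def merge(buffer: T, input: T): T

  // Produce the final result from the state.
  def eval(buffer: T): Any

  // Convert the state to/from bytes so partial results can be shuffled.
  // This requirement is what forces "form 3" on HiveUDAFFunction below.
  def serialize(buffer: T): Array[Byte]
  def deserialize(storageFormat: Array[Byte]): T
}
```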
The tricky part of this PR is mostly about updating and passing around the aggregation states of `HiveUDAFFunction`s, since the aggregation state of a Hive UDAF may appear in three different forms. Let's take the testing `MockUDAF` added in this PR as an example. This UDAF computes the count of non-null values together with the count of nulls of a given column. Its aggregation state may appear in the following forms at different times (a sketch of the full conversion chain follows the list):

1. A `MockUDAFBuffer`, which is a concrete subclass of `GenericUDAFEvaluator.AggregationBuffer`

   This is the form used by the Hive UDAF API itself, and it is required in the following scenarios:

   - Calling `GenericUDAFEvaluator.iterate()` to update an existing aggregation state with new input values.
   - Calling `GenericUDAFEvaluator.terminate()` to get the final aggregated value from an existing aggregation state.
   - Calling `GenericUDAFEvaluator.merge()` to merge other aggregation states into an existing aggregation state. The existing aggregation state to be updated must be in this form.

   Conversions:

   - To form 2: call `GenericUDAFEvaluator.terminatePartial()`.
   - To form 3: convert to form 2 first, then from form 2 to form 3.

2. An `Object[]` array containing two `java.lang.Long` values

   This is the form used to interact with Hive's `ObjectInspector`s, and it is required in the following scenarios:

   - Calling `GenericUDAFEvaluator.terminatePartial()` to convert an existing aggregation state from form 1 to form 2.
   - Calling `GenericUDAFEvaluator.merge()` to merge other aggregation states into an existing aggregation state. The input (partial) aggregation state must be in this form.

   Conversions:

   - To form 1: there is no direct method; create an empty `AggregationBuffer` and merge the form 2 value into it via `GenericUDAFEvaluator.merge()`.
   - To form 3: use the `unwrapperFor()`/`unwrap()` methods of `HiveInspectors`.

3. The byte array that holds the data of an `UnsafeRow` with two `LongType` fields

   This is the form used by Spark SQL to shuffle partial aggregation results. It is required because `TypedImperativeAggregate` always asks its subclasses to serialize their aggregation states into a byte array.

   Conversions:

   - To form 1: convert to form 2 first, then from form 2 to form 1.
   - To form 2: use the `wrapperFor()`/`wrap()` methods of `HiveInspectors`.
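To make the chain concrete, here is a minimal sketch of how `serialize()`/`deserialize()` can compose the conversions above. All parameter names (`unwrapper`, `wrapper`, `projection`, `partialResultDataType`) and the single-field partial-result schema are assumptions for illustration, standing in for state that `HiveUDAFFunction` would wire up from the UDAF's `ObjectInspector`s; this is not the actual `HiveUDAFFunction` code.

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.DataType

// Hypothetical helper: the parameters stand in for what HiveUDAFFunction
// would build from the UDAF's ObjectInspectors (via HiveInspectors).
object HiveUDAFStateConversions {

  // Form 1 -> 2 -> 3: Hive buffer -> Hive objects -> UnsafeRow bytes.
  def serialize(
      evaluator: GenericUDAFEvaluator,
      unwrapper: Any => Any,         // assumed: built via HiveInspectors.unwrapperFor
      projection: UnsafeProjection)( // assumed: over the partial-result schema
      buffer: AggregationBuffer): Array[Byte] = {
    val hiveObjects = evaluator.terminatePartial(buffer)  // form 1 -> form 2
    val catalystRow = InternalRow(unwrapper(hiveObjects)) // form 2 -> Catalyst value
    projection(catalystRow).getBytes                      // -> form 3 (shuffle-ready bytes)
  }

  // Form 3 -> 2 -> 1: bytes -> Hive objects, merged into an empty buffer,
  // because the Hive UDAF API has no direct "deserialize into a buffer" call.
  def deserialize(
      evaluator: GenericUDAFEvaluator,
      wrapper: Any => AnyRef,        // assumed: built via HiveInspectors.wrapperFor
      partialResultDataType: DataType)(
      bytes: Array[Byte]): AggregationBuffer = {
    val row = new UnsafeRow(1)
    row.pointTo(bytes, bytes.length)                             // form 3 -> Catalyst row
    val hiveObjects = wrapper(row.get(0, partialResultDataType)) // -> form 2
    val buffer = evaluator.getNewAggregationBuffer               // empty form 1 state
    evaluator.merge(buffer, hiveObjects)                         // fold form 2 into form 1
    buffer
  }
}
```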
Here are some micro-benchmark results produced with the most recent master and with this PR branch.
Master:
```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz
hive udaf vs spark af: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
w/o groupBy 339 / 372 3.1 323.2 1.0X
w/ groupBy 503 / 529 2.1 479.7 0.7X
```
This PR:
```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz
hive udaf vs spark af: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
w/o groupBy 116 / 126 9.0 110.8 1.0X
w/ groupBy 151 / 159 6.9 144.0 0.8X
```
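Comparing best times, the PR branch is roughly 3x faster than master in both cases: 116 ms vs. 339 ms without grouping, and 151 ms vs. 503 ms with grouping.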
Benchmark code snippet:
```scala
test("Hive UDAF benchmark") {
val N = 1 << 20
sparkSession.sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
val benchmark = new Benchmark(
name = "hive udaf vs spark af",
valuesPerIteration = N,
minNumIters = 5,
warmupTime = 5.seconds,
minTime = 5.seconds,
outputPerIteration = true
)
benchmark.addCase("w/o groupBy") { _ =>
sparkSession.range(N).agg("id" -> "hive_max").collect()
}
benchmark.addCase("w/ groupBy") { _ =>
sparkSession.range(N).groupBy($"id" % 10).agg("id" -> "hive_max").collect()
}
benchmark.run()
sparkSession.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max")
}
```
## How was this patch tested?
New test suite `HiveUDAFSuite` is added.
Author: Cheng Lian <lian@databricks.com>
Closes #15703 from liancheng/partial-agg-hive-udaf.
Diffstat (limited to 'sql/hive/src/test')
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala | 152 |
1 file changed, 152 insertions(+), 0 deletions(-)
```
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
new file mode 100644
index 0000000000..c9ef72ee11
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDAFEvaluator, GenericUDAFMax}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode}
+import org.apache.hadoop.hive.ql.util.JavaDataModel
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
+
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+
+class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
+  import testImplicits._
+
+  protected override def beforeAll(): Unit = {
+    sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'")
+    sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
+
+    Seq(
+      (0: Integer) -> "val_0",
+      (1: Integer) -> "val_1",
+      (2: Integer) -> null,
+      (3: Integer) -> null
+    ).toDF("key", "value").repartition(2).createOrReplaceTempView("t")
+  }
+
+  protected override def afterAll(): Unit = {
+    sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock")
+    sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max")
+  }
+
+  test("built-in Hive UDAF") {
+    val df = sql("SELECT key % 2, hive_max(key) FROM t GROUP BY key % 2")
+
+    val aggs = df.queryExecution.executedPlan.collect {
+      case agg: ObjectHashAggregateExec => agg
+    }
+
+    // There should be two aggregate operators, one for partial aggregation, and the other for
+    // global aggregation.
+    assert(aggs.length == 2)
+
+    checkAnswer(df, Seq(
+      Row(0, 2),
+      Row(1, 3)
+    ))
+  }
+
+  test("customized Hive UDAF") {
+    val df = sql("SELECT key % 2, mock(value) FROM t GROUP BY key % 2")
+
+    val aggs = df.queryExecution.executedPlan.collect {
+      case agg: ObjectHashAggregateExec => agg
+    }
+
+    // There should be two aggregate operators, one for partial aggregation, and the other for
+    // global aggregation.
+    assert(aggs.length == 2)
+
+    checkAnswer(df, Seq(
+      Row(0, Row(1, 1)),
+      Row(1, Row(1, 1))
+    ))
+  }
+}
+
+/**
+ * A testing Hive UDAF that computes the counts of both non-null values and nulls of a given
+ * column.
+ */
+class MockUDAF extends AbstractGenericUDAFResolver {
+  override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator
+}
+
+class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long)
+  extends GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+  override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
+}
+
+class MockUDAFEvaluator extends GenericUDAFEvaluator {
+  private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+  private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+  private val bufferOI = {
+    val fieldNames = Seq("nonNullCount", "nullCount").asJava
+    val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava
+    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
+  }
+
+  private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount")
+
+  private val nullCountField = bufferOI.getStructFieldRef("nullCount")
+
+  override def getNewAggregationBuffer: AggregationBuffer = new MockUDAFBuffer(0L, 0L)
+
+  override def reset(agg: AggregationBuffer): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    buffer.nonNullCount = 0L
+    buffer.nullCount = 0L
+  }
+
+  override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = bufferOI
+
+  override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    if (parameters.head eq null) {
+      buffer.nullCount += 1L
+    } else {
+      buffer.nonNullCount += 1L
+    }
+  }
+
+  override def merge(agg: AggregationBuffer, partial: Object): Unit = {
+    if (partial ne null) {
+      val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField))
+      val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField))
+      val buffer = agg.asInstanceOf[MockUDAFBuffer]
+      buffer.nonNullCount += nonNullCount
+      buffer.nullCount += nullCount
+    }
+  }
+
+  override def terminatePartial(agg: AggregationBuffer): AnyRef = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+  }
+
+  override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg)
+}
```