aboutsummaryrefslogtreecommitdiff
path: root/sql/hive/src/test
diff options
context:
space:
mode:
authorCheng Lian <lian@databricks.com>2016-11-16 14:32:36 -0800
committerYin Huai <yhuai@databricks.com>2016-11-16 14:32:36 -0800
commit2ca8ae9aa1b29bf1f46d0b656d9885e438e67f53 (patch)
treed254abf510f28e509b15fd97e7457e0b2ed66b27 /sql/hive/src/test
parenta36a76ac43c36a3b897a748bd9f138b629dbc684 (diff)
downloadspark-2ca8ae9aa1b29bf1f46d0b656d9885e438e67f53.tar.gz
spark-2ca8ae9aa1b29bf1f46d0b656d9885e438e67f53.tar.bz2
spark-2ca8ae9aa1b29bf1f46d0b656d9885e438e67f53.zip
[SPARK-18186] Migrate HiveUDAFFunction to TypedImperativeAggregate for partial aggregation support
## What changes were proposed in this pull request? While being evaluated in Spark SQL, Hive UDAFs don't support partial aggregation. This PR migrates `HiveUDAFFunction`s to `TypedImperativeAggregate`, which already provides partial aggregation support for aggregate functions that may use arbitrary Java objects as aggregation states. The following snippet shows the effect of this PR: ```scala import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax sql(s"CREATE FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") spark.range(100).createOrReplaceTempView("t") // A query using both Spark SQL native `max` and Hive `max` sql(s"SELECT max(id), hive_max(id) FROM t").explain() ``` Before this PR: ``` == Physical Plan == SortAggregate(key=[], functions=[max(id#1L), default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax7475f57e), id#1L, false, 0, 0)]) +- Exchange SinglePartition +- *Range (0, 100, step=1, splits=Some(1)) ``` After this PR: ``` == Physical Plan == SortAggregate(key=[], functions=[max(id#1L), default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax5e18a6a7), id#1L, false, 0, 0)]) +- Exchange SinglePartition +- SortAggregate(key=[], functions=[partial_max(id#1L), partial_default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax5e18a6a7), id#1L, false, 0, 0)]) +- *Range (0, 100, step=1, splits=Some(1)) ``` The tricky part of the PR is mostly about updating and passing around aggregation states of `HiveUDAFFunction`s since the aggregation state of a Hive UDAF may appear in three different forms. Let's take a look at the testing `MockUDAF` added in this PR as an example. This UDAF computes the count of non-null values together with the count of nulls of a given column. Its aggregation state may appear as the following forms at different time: 1. A `MockUDAFBuffer`, which is a concrete subclass of `GenericUDAFEvaluator.AggregationBuffer` The form used by Hive UDAF API. This form is required by the following scenarios: - Calling `GenericUDAFEvaluator.iterate()` to update an existing aggregation state with new input values. - Calling `GenericUDAFEvaluator.terminate()` to get the final aggregated value from an existing aggregation state. - Calling `GenericUDAFEvaluator.merge()` to merge other aggregation states into an existing aggregation state. The existing aggregation state to be updated must be in this form. Conversions: - To form 2: `GenericUDAFEvaluator.terminatePartial()` - To form 3: Convert to form 2 first, and then to 3. 2. An `Object[]` array containing two `java.lang.Long` values. The form used to interact with Hive's `ObjectInspector`s. This form is required by the following scenarios: - Calling `GenericUDAFEvaluator.terminatePartial()` to convert an existing aggregation state in form 1 to form 2. - Calling `GenericUDAFEvaluator.merge()` to merge other aggregation states into an existing aggregation state. The input aggregation state must be in this form. Conversions: - To form 1: No direct method. Have to create an empty `AggregationBuffer` and merge it into the empty buffer. - To form 3: `unwrapperFor()`/`unwrap()` method of `HiveInspectors` 3. The byte array that holds data of an `UnsafeRow` with two `LongType` fields. The form used by Spark SQL to shuffle partial aggregation results. This form is required because `TypedImperativeAggregate` always asks its subclasses to serialize their aggregation states into a byte array. Conversions: - To form 1: Convert to form 2 first, and then to 1. - To form 2: `wrapperFor()`/`wrap()` method of `HiveInspectors` Here're some micro-benchmark results produced by the most recent master and this PR branch. Master: ``` Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5 Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz hive udaf vs spark af: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ w/o groupBy 339 / 372 3.1 323.2 1.0X w/ groupBy 503 / 529 2.1 479.7 0.7X ``` This PR: ``` Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5 Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz hive udaf vs spark af: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ w/o groupBy 116 / 126 9.0 110.8 1.0X w/ groupBy 151 / 159 6.9 144.0 0.8X ``` Benchmark code snippet: ```scala test("Hive UDAF benchmark") { val N = 1 << 20 sparkSession.sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") val benchmark = new Benchmark( name = "hive udaf vs spark af", valuesPerIteration = N, minNumIters = 5, warmupTime = 5.seconds, minTime = 5.seconds, outputPerIteration = true ) benchmark.addCase("w/o groupBy") { _ => sparkSession.range(N).agg("id" -> "hive_max").collect() } benchmark.addCase("w/ groupBy") { _ => sparkSession.range(N).groupBy($"id" % 10).agg("id" -> "hive_max").collect() } benchmark.run() sparkSession.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") } ``` ## How was this patch tested? New test suite `HiveUDAFSuite` is added. Author: Cheng Lian <lian@databricks.com> Closes #15703 from liancheng/partial-agg-hive-udaf.
Diffstat (limited to 'sql/hive/src/test')
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala152
1 files changed, 152 insertions, 0 deletions
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
new file mode 100644
index 0000000000..c9ef72ee11
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDAFEvaluator, GenericUDAFMax}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode}
+import org.apache.hadoop.hive.ql.util.JavaDataModel
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
+
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+
+class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
+ import testImplicits._
+
+ protected override def beforeAll(): Unit = {
+ sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'")
+ sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
+
+ Seq(
+ (0: Integer) -> "val_0",
+ (1: Integer) -> "val_1",
+ (2: Integer) -> null,
+ (3: Integer) -> null
+ ).toDF("key", "value").repartition(2).createOrReplaceTempView("t")
+ }
+
+ protected override def afterAll(): Unit = {
+ sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock")
+ sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max")
+ }
+
+ test("built-in Hive UDAF") {
+ val df = sql("SELECT key % 2, hive_max(key) FROM t GROUP BY key % 2")
+
+ val aggs = df.queryExecution.executedPlan.collect {
+ case agg: ObjectHashAggregateExec => agg
+ }
+
+ // There should be two aggregate operators, one for partial aggregation, and the other for
+ // global aggregation.
+ assert(aggs.length == 2)
+
+ checkAnswer(df, Seq(
+ Row(0, 2),
+ Row(1, 3)
+ ))
+ }
+
+ test("customized Hive UDAF") {
+ val df = sql("SELECT key % 2, mock(value) FROM t GROUP BY key % 2")
+
+ val aggs = df.queryExecution.executedPlan.collect {
+ case agg: ObjectHashAggregateExec => agg
+ }
+
+ // There should be two aggregate operators, one for partial aggregation, and the other for
+ // global aggregation.
+ assert(aggs.length == 2)
+
+ checkAnswer(df, Seq(
+ Row(0, Row(1, 1)),
+ Row(1, Row(1, 1))
+ ))
+ }
+}
+
+/**
+ * A testing Hive UDAF that computes the counts of both non-null values and nulls of a given column.
+ */
+class MockUDAF extends AbstractGenericUDAFResolver {
+ override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator
+}
+
+class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long)
+ extends GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+ override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
+}
+
+class MockUDAFEvaluator extends GenericUDAFEvaluator {
+ private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+ private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+ private val bufferOI = {
+ val fieldNames = Seq("nonNullCount", "nullCount").asJava
+ val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava
+ ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
+ }
+
+ private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount")
+
+ private val nullCountField = bufferOI.getStructFieldRef("nullCount")
+
+ override def getNewAggregationBuffer: AggregationBuffer = new MockUDAFBuffer(0L, 0L)
+
+ override def reset(agg: AggregationBuffer): Unit = {
+ val buffer = agg.asInstanceOf[MockUDAFBuffer]
+ buffer.nonNullCount = 0L
+ buffer.nullCount = 0L
+ }
+
+ override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = bufferOI
+
+ override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = {
+ val buffer = agg.asInstanceOf[MockUDAFBuffer]
+ if (parameters.head eq null) {
+ buffer.nullCount += 1L
+ } else {
+ buffer.nonNullCount += 1L
+ }
+ }
+
+ override def merge(agg: AggregationBuffer, partial: Object): Unit = {
+ if (partial ne null) {
+ val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField))
+ val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField))
+ val buffer = agg.asInstanceOf[MockUDAFBuffer]
+ buffer.nonNullCount += nonNullCount
+ buffer.nullCount += nullCount
+ }
+ }
+
+ override def terminatePartial(agg: AggregationBuffer): AnyRef = {
+ val buffer = agg.asInstanceOf[MockUDAFBuffer]
+ Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+ }
+
+ override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg)
+}