-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala                  199
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala   152
2 files changed, 301 insertions, 50 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 42033080dc..32edd4aec2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -17,16 +17,18 @@
package org.apache.spark.sql.hive
+import java.nio.ByteBuffer
+
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.hive.ql.exec._
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
import org.apache.hadoop.hive.ql.udf.generic._
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
-import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector,
- ObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
import org.apache.spark.internal.Logging
@@ -58,7 +60,7 @@ private[hive] case class HiveSimpleUDF(
@transient
private lazy val isUDFDeterministic = {
- val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
+ val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
udfType != null && udfType.deterministic()
}
@@ -75,7 +77,7 @@ private[hive] case class HiveSimpleUDF(
@transient
lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
- method.getGenericReturnType(), ObjectInspectorOptions.JAVA))
+ method.getGenericReturnType, ObjectInspectorOptions.JAVA))
@transient
private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
@@ -263,8 +265,35 @@ private[hive] case class HiveGenericUDTF(
}
/**
- * Currently we don't support partial aggregation for queries using Hive UDAF, which may hurt
- * performance a lot.
+ * While being evaluated by Spark SQL, the aggregation state of a Hive UDAF may be represented in
+ * one of the following three formats:
+ *
+ * 1. An instance of some concrete `GenericUDAFEvaluator.AggregationBuffer` class
+ *
+ * This is the native Hive representation of an aggregation state. Hive `GenericUDAFEvaluator`
+ * methods like `iterate()`, `merge()`, `terminatePartial()`, and `terminate()` use this format.
+ * We call these methods to evaluate Hive UDAFs.
+ *
+ * 2. A Java object that can be inspected using the `ObjectInspector` returned by the
+ * `GenericUDAFEvaluator.init()` method.
+ *
+ * Hive uses this format to produce a serializable aggregation state so that it can shuffle
+ * partial aggregation results. Whenever we need to convert a Hive `AggregationBuffer` instance
+ * into a Spark SQL value, we have to convert it to this format first and then do the conversion
+ * with the help of `ObjectInspector`s.
+ *
+ * 3. A Spark SQL value
+ *
+ * We use this format for serializing Hive UDAF aggregation states on the Spark side. To be more
+ * specific, we convert `AggregationBuffer`s into equivalent Spark SQL values, write them into
+ * `UnsafeRow`s, and then retrieve the byte array behind those `UnsafeRow`s as serialization
+ * results.
+ *
+ * We may use the following methods to convert the aggregation state back and forth:
+ *
+ * - `wrap()`/`wrapperFor()`: from 3 to 2
+ * - `unwrap()`/`unwrapperFor()`: from 2 to 3
+ * - `GenericUDAFEvaluator.terminatePartial()`: from 1 to 2
*/
private[hive] case class HiveUDAFFunction(
name: String,
@@ -273,7 +302,7 @@ private[hive] case class HiveUDAFFunction(
isUDAFBridgeRequired: Boolean = false,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
- extends ImperativeAggregate with HiveInspectors {
+ extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors {
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -281,73 +310,73 @@ private[hive] case class HiveUDAFFunction(
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
+ // Hive `ObjectInspector`s for all child expressions (input parameters of the function).
@transient
- private lazy val resolver =
- if (isUDAFBridgeRequired) {
+ private lazy val inputInspectors = children.map(toInspector).toArray
+
+ // Spark SQL data types of input parameters.
+ @transient
+ private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+
+ private def newEvaluator(): GenericUDAFEvaluator = {
+ val resolver = if (isUDAFBridgeRequired) {
new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
} else {
funcWrapper.createFunction[AbstractGenericUDAFResolver]()
}
- @transient
- private lazy val inspectors = children.map(toInspector).toArray
-
- @transient
- private lazy val functionAndInspector = {
- val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
- val f = resolver.getEvaluator(parameterInfo)
- f -> f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
+ val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false)
+ resolver.getEvaluator(parameterInfo)
}
+ // The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
@transient
- private lazy val function = functionAndInspector._1
+ private lazy val partial1ModeEvaluator = newEvaluator()
+ // Hive `ObjectInspector` used to inspect partial aggregation results.
@transient
- private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+ private val partialResultInspector = partial1ModeEvaluator.init(
+ GenericUDAFEvaluator.Mode.PARTIAL1,
+ inputInspectors
+ )
+ // The UDAF evaluator used to merge partial aggregation results.
@transient
- private lazy val returnInspector = functionAndInspector._2
+ private lazy val partial2ModeEvaluator = {
+ val evaluator = newEvaluator()
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector))
+ evaluator
+ }
+ // Spark SQL data type of partial aggregation results
@transient
- private lazy val unwrapper = unwrapperFor(returnInspector)
+ private lazy val partialResultDataType = inspectorToDataType(partialResultInspector)
+ // The UDAF evaluator used to compute the final result from partial aggregation result objects.
@transient
- private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _
-
- override def eval(input: InternalRow): Any = unwrapper(function.evaluate(buffer))
+ private lazy val finalModeEvaluator = newEvaluator()
+ // Hive `ObjectInspector` used to inspect the final aggregation result object.
@transient
- private lazy val inputProjection = new InterpretedProjection(children)
+ private val returnInspector = finalModeEvaluator.init(
+ GenericUDAFEvaluator.Mode.FINAL,
+ Array(partialResultInspector)
+ )
+ // Wrapper functions used to wrap Spark SQL input arguments into the Hive-specific format.
@transient
- private lazy val cached = new Array[AnyRef](children.length)
+ private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+ // Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into
+ // the Spark SQL-specific format.
@transient
- private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
-
- // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation
- // buffer for it.
- override def aggBufferSchema: StructType = StructType(Nil)
-
- override def update(_buffer: InternalRow, input: InternalRow): Unit = {
- val inputs = inputProjection(input)
- function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes))
- }
-
- override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = {
- throw new UnsupportedOperationException(
- "Hive UDAF doesn't support partial aggregate")
- }
+ private lazy val resultUnwrapper = unwrapperFor(returnInspector)
- override def initialize(_buffer: InternalRow): Unit = {
- buffer = function.getNewAggregationBuffer
- }
-
- override val aggBufferAttributes: Seq[AttributeReference] = Nil
+ @transient
+ private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
- // Note: although this simply copies aggBufferAttributes, this common code can not be placed
- // in the superclass because that will lead to initialization ordering issues.
- override val inputAggBufferAttributes: Seq[AttributeReference] = Nil
+ @transient
+ private lazy val aggBufferSerDe: AggregationBufferSerDe = new AggregationBufferSerDe
// We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
// catalyst type checking framework.
@@ -355,7 +384,7 @@ private[hive] case class HiveUDAFFunction(
override def nullable: Boolean = true
- override def supportsPartial: Boolean = false
+ override def supportsPartial: Boolean = true
override lazy val dataType: DataType = inspectorToDataType(returnInspector)
@@ -365,4 +394,74 @@ private[hive] case class HiveUDAFFunction(
val distinct = if (isDistinct) "DISTINCT " else " "
s"$name($distinct${children.map(_.sql).mkString(", ")})"
}
+
+ override def createAggregationBuffer(): AggregationBuffer =
+ partial1ModeEvaluator.getNewAggregationBuffer
+
+ @transient
+ private lazy val inputProjection = UnsafeProjection.create(children)
+
+ override def update(buffer: AggregationBuffer, input: InternalRow): Unit = {
+ partial1ModeEvaluator.iterate(
+ buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+ }
+
+ override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = {
+ // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
+ // buffer in the 2nd format mentioned in the ScalaDoc of this class. Originally, Hive converts
+ // `AggregationBuffer`s into this format before shuffling partial aggregation results, and
+ // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
+ partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input))
+ }
+
+ override def eval(buffer: AggregationBuffer): Any = {
+ resultUnwrapper(finalModeEvaluator.terminate(buffer))
+ }
+
+ override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+ // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
+ // shuffle it for global aggregation later.
+ aggBufferSerDe.serialize(buffer)
+ }
+
+ override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+ // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
+ // for global aggregation by merging multiple partial aggregation results within a single group.
+ aggBufferSerDe.deserialize(bytes)
+ }
+
+ // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
+ private class AggregationBufferSerDe {
+ private val partialResultUnwrapper = unwrapperFor(partialResultInspector)
+
+ private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType)
+
+ private val projection = UnsafeProjection.create(Array(partialResultDataType))
+
+ private val mutableRow = new GenericInternalRow(1)
+
+ def serialize(buffer: AggregationBuffer): Array[Byte] = {
+ // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
+ // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
+ // Then we can unwrap it to a Spark SQL value.
+ mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer)))
+ val unsafeRow = projection(mutableRow)
+ val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
+ unsafeRow.writeTo(bytes)
+ bytes.array()
+ }
+
+ def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+ // `GenericUDAFEvaluator` doesn't provide any method capable of converting an object returned
+ // by `GenericUDAFEvaluator.terminatePartial()` back into an `AggregationBuffer`. The
+ // workaround here is to create an initial `AggregationBuffer` first and then merge the
+ // deserialized object into that buffer.
+ val buffer = partial2ModeEvaluator.getNewAggregationBuffer
+ val unsafeRow = new UnsafeRow(1)
+ unsafeRow.pointTo(bytes, bytes.length)
+ val partialResult = unsafeRow.get(0, partialResultDataType)
+ partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult))
+ buffer
+ }
+ }
}
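
To make the new control flow easier to follow than the diff above, here is a minimal, self-contained Scala sketch (not part of the patch; `ToyAggregate` and `CountBuffer` are hypothetical names) of the `TypedImperativeAggregate` lifecycle that `HiveUDAFFunction` now implements: create an opaque buffer per group, update it with raw rows, serialize it to bytes for the shuffle, deserialize and merge partial results on the reduce side, and finally evaluate.

import java.nio.ByteBuffer

// Hypothetical stand-in for a Hive AggregationBuffer: counts non-null and null inputs.
final case class CountBuffer(var nonNull: Long = 0L, var nulls: Long = 0L)

object ToyAggregate {
  // Analogue of createAggregationBuffer(): a fresh, empty buffer per aggregation group.
  def createAggregationBuffer(): CountBuffer = CountBuffer()

  // Analogue of update(): fold one raw input value into the buffer.
  def update(buffer: CountBuffer, input: Any): Unit =
    if (input == null) buffer.nulls += 1 else buffer.nonNull += 1

  // Analogue of serialize(): flatten the buffer into bytes so it can be shuffled.
  def serialize(buffer: CountBuffer): Array[Byte] = {
    val bytes = ByteBuffer.allocate(16)
    bytes.putLong(buffer.nonNull).putLong(buffer.nulls)
    bytes.array()
  }

  // Analogue of deserialize(): rebuild a buffer from the shuffled bytes.
  def deserialize(bytes: Array[Byte]): CountBuffer = {
    val bb = ByteBuffer.wrap(bytes)
    CountBuffer(bb.getLong(), bb.getLong())
  }

  // Analogue of merge(): fold one partial result into another.
  def merge(buffer: CountBuffer, other: CountBuffer): Unit = {
    buffer.nonNull += other.nonNull
    buffer.nulls += other.nulls
  }

  // Analogue of eval(): produce the final result from the merged buffer.
  def eval(buffer: CountBuffer): (Long, Long) = (buffer.nonNull, buffer.nulls)
}

object ToyAggregateDemo extends App {
  // Two simulated partitions of raw input, each aggregated independently (partial aggregation).
  val partitions = Seq(Seq("a", null, "b"), Seq(null, "c"))
  val partials = partitions.map { rows =>
    val buffer = ToyAggregate.createAggregationBuffer()
    rows.foreach(ToyAggregate.update(buffer, _))
    ToyAggregate.serialize(buffer) // the bytes that would be shuffled
  }

  // Final aggregation: deserialize each partial result and merge into a single buffer.
  val merged = ToyAggregate.createAggregationBuffer()
  partials.map(ToyAggregate.deserialize).foreach(ToyAggregate.merge(merged, _))
  val (nonNull, nulls) = ToyAggregate.eval(merged)
  assert(nonNull == 3 && nulls == 2) // "a", "b", "c" are non-null; two inputs are null
}
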
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
new file mode 100644
index 0000000000..c9ef72ee11
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDAFEvaluator, GenericUDAFMax}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode}
+import org.apache.hadoop.hive.ql.util.JavaDataModel
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
+
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+
+class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
+ import testImplicits._
+
+ protected override def beforeAll(): Unit = {
+ sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'")
+ sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
+
+ Seq(
+ (0: Integer) -> "val_0",
+ (1: Integer) -> "val_1",
+ (2: Integer) -> null,
+ (3: Integer) -> null
+ ).toDF("key", "value").repartition(2).createOrReplaceTempView("t")
+ }
+
+ protected override def afterAll(): Unit = {
+ sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock")
+ sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max")
+ }
+
+ test("built-in Hive UDAF") {
+ val df = sql("SELECT key % 2, hive_max(key) FROM t GROUP BY key % 2")
+
+ val aggs = df.queryExecution.executedPlan.collect {
+ case agg: ObjectHashAggregateExec => agg
+ }
+
+ // There should be two aggregate operators, one for partial aggregation, and the other for
+ // global aggregation.
+ assert(aggs.length == 2)
+
+ checkAnswer(df, Seq(
+ Row(0, 2),
+ Row(1, 3)
+ ))
+ }
+
+ test("customized Hive UDAF") {
+ val df = sql("SELECT key % 2, mock(value) FROM t GROUP BY key % 2")
+
+ val aggs = df.queryExecution.executedPlan.collect {
+ case agg: ObjectHashAggregateExec => agg
+ }
+
+ // There should be two aggregate operators, one for partial aggregation, and the other for
+ // global aggregation.
+ assert(aggs.length == 2)
+
+ checkAnswer(df, Seq(
+ Row(0, Row(1, 1)),
+ Row(1, Row(1, 1))
+ ))
+ }
+}
+
+/**
+ * A testing Hive UDAF that computes the counts of both non-null values and nulls of a given column.
+ */
+class MockUDAF extends AbstractGenericUDAFResolver {
+ override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator
+}
+
+class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long)
+ extends GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+ override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
+}
+
+class MockUDAFEvaluator extends GenericUDAFEvaluator {
+ private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+ private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+ private val bufferOI = {
+ val fieldNames = Seq("nonNullCount", "nullCount").asJava
+ val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava
+ ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
+ }
+
+ private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount")
+
+ private val nullCountField = bufferOI.getStructFieldRef("nullCount")
+
+ override def getNewAggregationBuffer: AggregationBuffer = new MockUDAFBuffer(0L, 0L)
+
+ override def reset(agg: AggregationBuffer): Unit = {
+ val buffer = agg.asInstanceOf[MockUDAFBuffer]
+ buffer.nonNullCount = 0L
+ buffer.nullCount = 0L
+ }
+
+ override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = bufferOI
+
+ override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = {
+ val buffer = agg.asInstanceOf[MockUDAFBuffer]
+ if (parameters.head eq null) {
+ buffer.nullCount += 1L
+ } else {
+ buffer.nonNullCount += 1L
+ }
+ }
+
+ override def merge(agg: AggregationBuffer, partial: Object): Unit = {
+ if (partial ne null) {
+ val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField))
+ val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField))
+ val buffer = agg.asInstanceOf[MockUDAFBuffer]
+ buffer.nonNullCount += nonNullCount
+ buffer.nullCount += nullCount
+ }
+ }
+
+ override def terminatePartial(agg: AggregationBuffer): AnyRef = {
+ val buffer = agg.asInstanceOf[MockUDAFBuffer]
+ Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+ }
+
+ override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg)
+}
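
As a usage sketch (not part of the patch), the `MockUDAFEvaluator` defined above can be driven by hand through the same two-phase path that the planner now exercises for Hive UDAFs: `iterate()` accumulates rows into per-partition buffers, `terminatePartial()` exposes each buffer in the inspectable format, and `merge()` plus `terminate()` combine the partial results on the final side. The demo object name is hypothetical and assumes the test classes above are on the classpath.

// Assumes the same package as the test file above (org.apache.spark.sql.hive.execution).
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer

object MockUDAFEvaluatorDemo extends App {
  // A real Hive evaluator would be init()-ed with the proper Mode first;
  // MockUDAFEvaluator's iterate/merge/terminate do not depend on that state.
  val evaluator = new MockUDAFEvaluator

  // Partial aggregation: accumulate one simulated partition into a fresh buffer and
  // convert it into the shuffle-friendly format with terminatePartial().
  def partialResult(rows: Seq[AnyRef]): AnyRef = {
    val buffer: AggregationBuffer = evaluator.getNewAggregationBuffer
    rows.foreach(row => evaluator.iterate(buffer, Array(row)))
    evaluator.terminatePartial(buffer)
  }

  val partials = Seq(
    partialResult(Seq("val_0", null)),
    partialResult(Seq("val_1", null))
  )

  // Final aggregation: merge the partial results into one buffer and terminate.
  val merged = evaluator.getNewAggregationBuffer
  partials.foreach(evaluator.merge(merged, _))

  val result = evaluator.terminate(merged).asInstanceOf[Array[Object]]
  assert(result.toSeq == Seq(2L, 2L)) // 2 non-null values and 2 nulls in total
}
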