diff options
-rw-r--r-- | sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 199 | ||||
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala | 152 |
2 files changed, 301 insertions, 50 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 42033080dc..32edd4aec2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.hive +import java.nio.ByteBuffer + import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, - ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.spark.internal.Logging @@ -58,7 +60,7 @@ private[hive] case class HiveSimpleUDF( @transient private lazy val isUDFDeterministic = { - val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) + val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) udfType != null && udfType.deterministic() } @@ -75,7 +77,7 @@ private[hive] case class HiveSimpleUDF( @transient lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector( - method.getGenericReturnType(), ObjectInspectorOptions.JAVA)) + method.getGenericReturnType, ObjectInspectorOptions.JAVA)) @transient private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) @@ -263,8 +265,35 @@ private[hive] case class HiveGenericUDTF( } /** - * Currently we don't support partial aggregation for queries using Hive UDAF, which may hurt - * performance a lot. + * While being evaluated by Spark SQL, the aggregation state of a Hive UDAF may be in the following + * three formats: + * + * 1. An instance of some concrete `GenericUDAFEvaluator.AggregationBuffer` class + * + * This is the native Hive representation of an aggregation state. Hive `GenericUDAFEvaluator` + * methods like `iterate()`, `merge()`, `terminatePartial()`, and `terminate()` use this format. + * We call these methods to evaluate Hive UDAFs. + * + * 2. A Java object that can be inspected using the `ObjectInspector` returned by the + * `GenericUDAFEvaluator.init()` method. + * + * Hive uses this format to produce a serializable aggregation state so that it can shuffle + * partial aggregation results. Whenever we need to convert a Hive `AggregationBuffer` instance + * into a Spark SQL value, we have to convert it to this format first and then do the conversion + * with the help of `ObjectInspector`s. + * + * 3. A Spark SQL value + * + * We use this format for serializing Hive UDAF aggregation states on Spark side. To be more + * specific, we convert `AggregationBuffer`s into equivalent Spark SQL values, write them into + * `UnsafeRow`s, and then retrieve the byte array behind those `UnsafeRow`s as serialization + * results. + * + * We may use the following methods to convert the aggregation state back and forth: + * + * - `wrap()`/`wrapperFor()`: from 3 to 1 + * - `unwrap()`/`unwrapperFor()`: from 1 to 3 + * - `GenericUDAFEvaluator.terminatePartial()`: from 2 to 3 */ private[hive] case class HiveUDAFFunction( name: String, @@ -273,7 +302,7 @@ private[hive] case class HiveUDAFFunction( isUDAFBridgeRequired: Boolean = false, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with HiveInspectors { + extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -281,73 +310,73 @@ private[hive] case class HiveUDAFFunction( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) + // Hive `ObjectInspector`s for all child expressions (input parameters of the function). @transient - private lazy val resolver = - if (isUDAFBridgeRequired) { + private lazy val inputInspectors = children.map(toInspector).toArray + + // Spark SQL data types of input parameters. + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + + private def newEvaluator(): GenericUDAFEvaluator = { + val resolver = if (isUDAFBridgeRequired) { new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) } else { funcWrapper.createFunction[AbstractGenericUDAFResolver]() } - @transient - private lazy val inspectors = children.map(toInspector).toArray - - @transient - private lazy val functionAndInspector = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - val f = resolver.getEvaluator(parameterInfo) - f -> f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) + resolver.getEvaluator(parameterInfo) } + // The UDAF evaluator used to consume raw input rows and produce partial aggregation results. @transient - private lazy val function = functionAndInspector._1 + private lazy val partial1ModeEvaluator = newEvaluator() + // Hive `ObjectInspector` used to inspect partial aggregation results. @transient - private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + private val partialResultInspector = partial1ModeEvaluator.init( + GenericUDAFEvaluator.Mode.PARTIAL1, + inputInspectors + ) + // The UDAF evaluator used to merge partial aggregation results. @transient - private lazy val returnInspector = functionAndInspector._2 + private lazy val partial2ModeEvaluator = { + val evaluator = newEvaluator() + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector)) + evaluator + } + // Spark SQL data type of partial aggregation results @transient - private lazy val unwrapper = unwrapperFor(returnInspector) + private lazy val partialResultDataType = inspectorToDataType(partialResultInspector) + // The UDAF evaluator used to compute the final result from a partial aggregation result objects. @transient - private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _ - - override def eval(input: InternalRow): Any = unwrapper(function.evaluate(buffer)) + private lazy val finalModeEvaluator = newEvaluator() + // Hive `ObjectInspector` used to inspect the final aggregation result object. @transient - private lazy val inputProjection = new InterpretedProjection(children) + private val returnInspector = finalModeEvaluator.init( + GenericUDAFEvaluator.Mode.FINAL, + Array(partialResultInspector) + ) + // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format. @transient - private lazy val cached = new Array[AnyRef](children.length) + private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + // Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into + // Spark SQL specific format. @transient - private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray - - // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation - // buffer for it. - override def aggBufferSchema: StructType = StructType(Nil) - - override def update(_buffer: InternalRow, input: InternalRow): Unit = { - val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes)) - } - - override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { - throw new UnsupportedOperationException( - "Hive UDAF doesn't support partial aggregate") - } + private lazy val resultUnwrapper = unwrapperFor(returnInspector) - override def initialize(_buffer: InternalRow): Unit = { - buffer = function.getNewAggregationBuffer - } - - override val aggBufferAttributes: Seq[AttributeReference] = Nil + @transient + private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. - override val inputAggBufferAttributes: Seq[AttributeReference] = Nil + @transient + private lazy val aggBufferSerDe: AggregationBufferSerDe = new AggregationBufferSerDe // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our // catalyst type checking framework. @@ -355,7 +384,7 @@ private[hive] case class HiveUDAFFunction( override def nullable: Boolean = true - override def supportsPartial: Boolean = false + override def supportsPartial: Boolean = true override lazy val dataType: DataType = inspectorToDataType(returnInspector) @@ -365,4 +394,74 @@ private[hive] case class HiveUDAFFunction( val distinct = if (isDistinct) "DISTINCT " else " " s"$name($distinct${children.map(_.sql).mkString(", ")})" } + + override def createAggregationBuffer(): AggregationBuffer = + partial1ModeEvaluator.getNewAggregationBuffer + + @transient + private lazy val inputProjection = UnsafeProjection.create(children) + + override def update(buffer: AggregationBuffer, input: InternalRow): Unit = { + partial1ModeEvaluator.iterate( + buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes)) + } + + override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = { + // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation + // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts + // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and + // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion. + partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input)) + } + + override def eval(buffer: AggregationBuffer): Any = { + resultUnwrapper(finalModeEvaluator.terminate(buffer)) + } + + override def serialize(buffer: AggregationBuffer): Array[Byte] = { + // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can + // shuffle it for global aggregation later. + aggBufferSerDe.serialize(buffer) + } + + override def deserialize(bytes: Array[Byte]): AggregationBuffer = { + // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare + // for global aggregation by merging multiple partial aggregation results within a single group. + aggBufferSerDe.deserialize(bytes) + } + + // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects + private class AggregationBufferSerDe { + private val partialResultUnwrapper = unwrapperFor(partialResultInspector) + + private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType) + + private val projection = UnsafeProjection.create(Array(partialResultDataType)) + + private val mutableRow = new GenericInternalRow(1) + + def serialize(buffer: AggregationBuffer): Array[Byte] = { + // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object + // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`. + // Then we can unwrap it to a Spark SQL value. + mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer))) + val unsafeRow = projection(mutableRow) + val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes) + unsafeRow.writeTo(bytes) + bytes.array() + } + + def deserialize(bytes: Array[Byte]): AggregationBuffer = { + // `GenericUDAFEvaluator` doesn't provide any method that is capable to convert an object + // returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The + // workaround here is creating an initial `AggregationBuffer` first and then merge the + // deserialized object into the buffer. + val buffer = partial2ModeEvaluator.getNewAggregationBuffer + val unsafeRow = new UnsafeRow(1) + unsafeRow.pointTo(bytes, bytes.length) + val partialResult = unsafeRow.get(0, partialResultDataType) + partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult)) + buffer + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala new file mode 100644 index 0000000000..c9ef72ee11 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDAFEvaluator, GenericUDAFMax} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode} +import org.apache.hadoop.hive.ql.util.JavaDataModel +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + +class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { + import testImplicits._ + + protected override def beforeAll(): Unit = { + sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'") + sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") + + Seq( + (0: Integer) -> "val_0", + (1: Integer) -> "val_1", + (2: Integer) -> null, + (3: Integer) -> null + ).toDF("key", "value").repartition(2).createOrReplaceTempView("t") + } + + protected override def afterAll(): Unit = { + sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock") + sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + } + + test("built-in Hive UDAF") { + val df = sql("SELECT key % 2, hive_max(key) FROM t GROUP BY key % 2") + + val aggs = df.queryExecution.executedPlan.collect { + case agg: ObjectHashAggregateExec => agg + } + + // There should be two aggregate operators, one for partial aggregation, and the other for + // global aggregation. + assert(aggs.length == 2) + + checkAnswer(df, Seq( + Row(0, 2), + Row(1, 3) + )) + } + + test("customized Hive UDAF") { + val df = sql("SELECT key % 2, mock(value) FROM t GROUP BY key % 2") + + val aggs = df.queryExecution.executedPlan.collect { + case agg: ObjectHashAggregateExec => agg + } + + // There should be two aggregate operators, one for partial aggregation, and the other for + // global aggregation. + assert(aggs.length == 2) + + checkAnswer(df, Seq( + Row(0, Row(1, 1)), + Row(1, Row(1, 1)) + )) + } +} + +/** + * A testing Hive UDAF that computes the counts of both non-null values and nulls of a given column. + */ +class MockUDAF extends AbstractGenericUDAFResolver { + override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator +} + +class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long) + extends GenericUDAFEvaluator.AbstractAggregationBuffer { + + override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2 +} + +class MockUDAFEvaluator extends GenericUDAFEvaluator { + private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector + + private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector + + private val bufferOI = { + val fieldNames = Seq("nonNullCount", "nullCount").asJava + val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava + ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs) + } + + private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount") + + private val nullCountField = bufferOI.getStructFieldRef("nullCount") + + override def getNewAggregationBuffer: AggregationBuffer = new MockUDAFBuffer(0L, 0L) + + override def reset(agg: AggregationBuffer): Unit = { + val buffer = agg.asInstanceOf[MockUDAFBuffer] + buffer.nonNullCount = 0L + buffer.nullCount = 0L + } + + override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = bufferOI + + override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = { + val buffer = agg.asInstanceOf[MockUDAFBuffer] + if (parameters.head eq null) { + buffer.nullCount += 1L + } else { + buffer.nonNullCount += 1L + } + } + + override def merge(agg: AggregationBuffer, partial: Object): Unit = { + if (partial ne null) { + val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField)) + val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField)) + val buffer = agg.asInstanceOf[MockUDAFBuffer] + buffer.nonNullCount += nonNullCount + buffer.nullCount += nullCount + } + } + + override def terminatePartial(agg: AggregationBuffer): AnyRef = { + val buffer = agg.asInstanceOf[MockUDAFBuffer] + Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long) + } + + override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg) +} |