diff options
author | Wenchen Fan <cloud0fan@163.com> | 2015-09-24 09:54:07 -0700 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2015-09-24 09:54:07 -0700 |
commit | 341b13f8f5eb118f1fb4d4f84418715ac4750a4d (patch) | |
tree | e6b8d90482e7d8e6cc339a6f357d320e92d7e8ff /sql/hive | |
parent | 02144d6745ec0a6d8877d969feb82139bd22437f (diff) | |
download | spark-341b13f8f5eb118f1fb4d4f84418715ac4750a4d.tar.gz spark-341b13f8f5eb118f1fb4d4f84418715ac4750a4d.tar.bz2 spark-341b13f8f5eb118f1fb4d4f84418715ac4750a4d.zip |
[SPARK-10765] [SQL] use new aggregate interface for hive UDAF
Author: Wenchen Fan <cloud0fan@163.com>
Closes #8874 from cloud-fan/hive-agg.
Diffstat (limited to 'sql/hive')
-rw-r--r-- | sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 139 | ||||
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala | 2 |
2 files changed, 58 insertions, 83 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index cad02373e5..fa9012b96e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -65,9 +66,10 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
         HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children)
       } else if (
            classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children)
+        HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children)
       } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveUDAF(new HiveFunctionWrapper(functionClassName), children)
+        HiveUDAFFunction(
+          new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
       } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
         HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children)
       } else {
@@ -441,70 +443,6 @@ private[hive] case class HiveWindowFunction(
     new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children)
 }
 
-private[hive] case class HiveGenericUDAF(
-    funcWrapper: HiveFunctionWrapper,
-    children: Seq[Expression]) extends AggregateExpression1
-  with HiveInspectors {
-
-  type UDFType = AbstractGenericUDAFResolver
-
-  @transient
-  protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction()
-
-  @transient
-  protected lazy val objectInspector = {
-    val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
-    resolver.getEvaluator(parameterInfo)
-      .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
-  }
-
-  @transient
-  protected lazy val inspectors = children.map(toInspector)
-
-  def dataType: DataType = inspectorToDataType(objectInspector)
-
-  def nullable: Boolean = true
-
-  override def toString: String = {
-    s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
-  }
-
-  def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this)
-}
-
-/** It is used as a wrapper for the hive functions which uses UDAF interface */
-private[hive] case class HiveUDAF(
-    funcWrapper: HiveFunctionWrapper,
-    children: Seq[Expression]) extends AggregateExpression1
-  with HiveInspectors {
-
-  type UDFType = UDAF
-
-  @transient
-  protected lazy val resolver: AbstractGenericUDAFResolver =
-    new GenericUDAFBridge(funcWrapper.createFunction())
-
-  @transient
-  protected lazy val objectInspector = {
-    val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
-    resolver.getEvaluator(parameterInfo)
-      .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
-  }
-
-  @transient
-  protected lazy val inspectors = children.map(toInspector)
-
-  def dataType: DataType = inspectorToDataType(objectInspector)
-
-  def nullable: Boolean = true
-
-  override def toString: String = {
-    s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
-  }
-
-  def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true)
-}
-
 /**
  * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
  * [[Generator]]. Note that the semantics of Generators do not allow
@@ -584,49 +522,86 @@ private[hive] case class HiveGenericUDTF(
   }
 }
 
+/**
+ * Currently we don't support partial aggregation for queries using Hive UDAF, which may hurt
+ * performance a lot.
+ */
 private[hive] case class HiveUDAFFunction(
     funcWrapper: HiveFunctionWrapper,
-    exprs: Seq[Expression],
-    base: AggregateExpression1,
+    children: Seq[Expression],
     isUDAFBridgeRequired: Boolean = false)
-  extends AggregateFunction1
-  with HiveInspectors {
+  extends AggregateFunction2 with HiveInspectors {
 
-  def this() = this(null, null, null)
+  def this() = this(null, null)
 
-  private val resolver =
+  @transient
+  private lazy val resolver =
     if (isUDAFBridgeRequired) {
       new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
     } else {
       funcWrapper.createFunction[AbstractGenericUDAFResolver]()
     }
 
-  private val inspectors = exprs.map(toInspector).toArray
+  @transient
+  private lazy val inspectors = children.map(toInspector).toArray
 
-  private val function = {
+  @transient
+  private lazy val functionAndInspector = {
     val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
-    resolver.getEvaluator(parameterInfo)
+    val f = resolver.getEvaluator(parameterInfo)
+    f -> f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
   }
 
-  private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
+  @transient
+  private lazy val function = functionAndInspector._1
+
+  @transient
+  private lazy val returnInspector = functionAndInspector._2
 
-  private val buffer =
-    function.getNewAggregationBuffer
+  @transient
+  private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _
 
   override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector)
 
   @transient
-  val inputProjection = new InterpretedProjection(exprs)
+  private lazy val inputProjection = new InterpretedProjection(children)
 
   @transient
-  protected lazy val cached = new Array[AnyRef](exprs.length)
+  private lazy val cached = new Array[AnyRef](children.length)
 
   @transient
-  private lazy val inputDataTypes: Array[DataType] = exprs.map(_.dataType).toArray
+  private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+
+  // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation
+  // buffer for it.
+  override def bufferSchema: StructType = StructType(Nil)
 
-  def update(input: InternalRow): Unit = {
+  override def update(_buffer: MutableRow, input: InternalRow): Unit = {
     val inputs = inputProjection(input)
     function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes))
   }
+
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    throw new UnsupportedOperationException(
+      "Hive UDAF doesn't support partial aggregate")
+  }
+
+  override def cloneBufferAttributes: Seq[Attribute] = Nil
+
+  override def initialize(_buffer: MutableRow): Unit = {
+    buffer = function.getNewAggregationBuffer
+  }
+
+  override def bufferAttributes: Seq[AttributeReference] = Nil
+
+  // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
+  // catalyst type checking framework.
+  override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)
+
+  override def nullable: Boolean = true
+
+  override def supportsPartial: Boolean = false
+
+  override lazy val dataType: DataType = inspectorToDataType(returnInspector)
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index d9ba895e1e..3c8a0091c8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -131,7 +131,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton {
     hiveContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
   }
 
-  test("SPARK-6409 UDAFAverage test") {
+  test("SPARK-6409 UDAF Average test") {
     sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")
     checkAnswer(
       sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"),