aboutsummaryrefslogtreecommitdiff
path: root/sql/hive
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@163.com>2015-09-24 09:54:07 -0700
committerYin Huai <yhuai@databricks.com>2015-09-24 09:54:07 -0700
commit341b13f8f5eb118f1fb4d4f84418715ac4750a4d (patch)
treee6b8d90482e7d8e6cc339a6f357d320e92d7e8ff /sql/hive
parent02144d6745ec0a6d8877d969feb82139bd22437f (diff)
downloadspark-341b13f8f5eb118f1fb4d4f84418715ac4750a4d.tar.gz
spark-341b13f8f5eb118f1fb4d4f84418715ac4750a4d.tar.bz2
spark-341b13f8f5eb118f1fb4d4f84418715ac4750a4d.zip
[SPARK-10765] [SQL] use new aggregate interface for hive UDAF
Author: Wenchen Fan <cloud0fan@163.com> Closes #8874 from cloud-fan/hive-agg.
Diffstat (limited to 'sql/hive')
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala139
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala2
2 files changed, 58 insertions, 83 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index cad02373e5..fa9012b96e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
@@ -65,9 +66,10 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children)
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children)
+ HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveUDAF(new HiveFunctionWrapper(functionClassName), children)
+ HiveUDAFFunction(
+ new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children)
} else {
@@ -441,70 +443,6 @@ private[hive] case class HiveWindowFunction(
new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children)
}
-private[hive] case class HiveGenericUDAF(
- funcWrapper: HiveFunctionWrapper,
- children: Seq[Expression]) extends AggregateExpression1
- with HiveInspectors {
-
- type UDFType = AbstractGenericUDAFResolver
-
- @transient
- protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction()
-
- @transient
- protected lazy val objectInspector = {
- val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
- resolver.getEvaluator(parameterInfo)
- .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
- }
-
- @transient
- protected lazy val inspectors = children.map(toInspector)
-
- def dataType: DataType = inspectorToDataType(objectInspector)
-
- def nullable: Boolean = true
-
- override def toString: String = {
- s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
- }
-
- def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this)
-}
-
-/** It is used as a wrapper for the hive functions which uses UDAF interface */
-private[hive] case class HiveUDAF(
- funcWrapper: HiveFunctionWrapper,
- children: Seq[Expression]) extends AggregateExpression1
- with HiveInspectors {
-
- type UDFType = UDAF
-
- @transient
- protected lazy val resolver: AbstractGenericUDAFResolver =
- new GenericUDAFBridge(funcWrapper.createFunction())
-
- @transient
- protected lazy val objectInspector = {
- val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
- resolver.getEvaluator(parameterInfo)
- .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
- }
-
- @transient
- protected lazy val inspectors = children.map(toInspector)
-
- def dataType: DataType = inspectorToDataType(objectInspector)
-
- def nullable: Boolean = true
-
- override def toString: String = {
- s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
- }
-
- def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true)
-}
-
/**
* Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
* [[Generator]]. Note that the semantics of Generators do not allow
@@ -584,49 +522,86 @@ private[hive] case class HiveGenericUDTF(
}
}
+/**
+ * Currently we don't support partial aggregation for queries using Hive UDAF, which may hurt
+ * performance a lot.
+ */
private[hive] case class HiveUDAFFunction(
funcWrapper: HiveFunctionWrapper,
- exprs: Seq[Expression],
- base: AggregateExpression1,
+ children: Seq[Expression],
isUDAFBridgeRequired: Boolean = false)
- extends AggregateFunction1
- with HiveInspectors {
+ extends AggregateFunction2 with HiveInspectors {
- def this() = this(null, null, null)
+ def this() = this(null, null)
- private val resolver =
+ @transient
+ private lazy val resolver =
if (isUDAFBridgeRequired) {
new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
} else {
funcWrapper.createFunction[AbstractGenericUDAFResolver]()
}
- private val inspectors = exprs.map(toInspector).toArray
+ @transient
+ private lazy val inspectors = children.map(toInspector).toArray
- private val function = {
+ @transient
+ private lazy val functionAndInspector = {
val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
- resolver.getEvaluator(parameterInfo)
+ val f = resolver.getEvaluator(parameterInfo)
+ f -> f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
}
- private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
+ @transient
+ private lazy val function = functionAndInspector._1
+
+ @transient
+ private lazy val returnInspector = functionAndInspector._2
- private val buffer =
- function.getNewAggregationBuffer
+ @transient
+ private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _
override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector)
@transient
- val inputProjection = new InterpretedProjection(exprs)
+ private lazy val inputProjection = new InterpretedProjection(children)
@transient
- protected lazy val cached = new Array[AnyRef](exprs.length)
+ private lazy val cached = new Array[AnyRef](children.length)
@transient
- private lazy val inputDataTypes: Array[DataType] = exprs.map(_.dataType).toArray
+ private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+
+ // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation
+ // buffer for it.
+ override def bufferSchema: StructType = StructType(Nil)
- def update(input: InternalRow): Unit = {
+ override def update(_buffer: MutableRow, input: InternalRow): Unit = {
val inputs = inputProjection(input)
function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes))
}
+
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ throw new UnsupportedOperationException(
+ "Hive UDAF doesn't support partial aggregate")
+ }
+
+ override def cloneBufferAttributes: Seq[Attribute] = Nil
+
+ override def initialize(_buffer: MutableRow): Unit = {
+ buffer = function.getNewAggregationBuffer
+ }
+
+ override def bufferAttributes: Seq[AttributeReference] = Nil
+
+ // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
+ // catalyst type checking framework.
+ override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)
+
+ override def nullable: Boolean = true
+
+ override def supportsPartial: Boolean = false
+
+ override lazy val dataType: DataType = inspectorToDataType(returnInspector)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index d9ba895e1e..3c8a0091c8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -131,7 +131,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton {
hiveContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
}
- test("SPARK-6409 UDAFAverage test") {
+ test("SPARK-6409 UDAF Average test") {
sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")
checkAnswer(
sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"),