about summary refs log tree commit diff
path: root/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala')
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala 15
1 file changed, 12 insertions(+), 3 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 962dd5a52e..d54913518b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -71,6 +71,9 @@ private[hive] case class HiveSimpleUDF(
override lazy val dataType = javaClassToDataType(method.getReturnType)
@transient
+ private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+
+ @transient
lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
method.getGenericReturnType(), ObjectInspectorOptions.JAVA))
@@ -82,7 +85,7 @@ private[hive] case class HiveSimpleUDF(
// TODO: Finish input output types.
override def eval(input: InternalRow): Any = {
- val inputs = wrap(children.map(_.eval(input)), arguments, cached, inputDataTypes)
+ val inputs = wrap(children.map(_.eval(input)), wrappers, cached, inputDataTypes)
val ret = FunctionRegistry.invoke(
method,
function,
@@ -215,6 +218,9 @@ private[hive] case class HiveGenericUDTF(
private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
@transient
+ private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+
+ @transient
private lazy val unwrapper = unwrapperFor(outputInspector)
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -222,7 +228,7 @@ private[hive] case class HiveGenericUDTF(
val inputProjection = new InterpretedProjection(children)
- function.process(wrap(inputProjection(input), inputInspectors, udtInput, inputDataTypes))
+ function.process(wrap(inputProjection(input), wrappers, udtInput, inputDataTypes))
collector.collectRows()
}
@@ -297,6 +303,9 @@ private[hive] case class HiveUDAFFunction(
private lazy val function = functionAndInspector._1
@transient
+ private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+
+ @transient
private lazy val returnInspector = functionAndInspector._2
@transient
@@ -322,7 +331,7 @@ private[hive] case class HiveUDAFFunction(
override def update(_buffer: MutableRow, input: InternalRow): Unit = {
val inputs = inputProjection(input)
- function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes))
+ function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes))
}
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {