diff options
Diffstat (limited to 'sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala')
-rw-r--r-- | sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 962dd5a52e..d54913518b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -71,6 +71,9 @@ private[hive] case class HiveSimpleUDF( override lazy val dataType = javaClassToDataType(method.getReturnType) @transient + private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + + @transient lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector( method.getGenericReturnType(), ObjectInspectorOptions.JAVA)) @@ -82,7 +85,7 @@ private[hive] case class HiveSimpleUDF( // TODO: Finish input output types. override def eval(input: InternalRow): Any = { - val inputs = wrap(children.map(_.eval(input)), arguments, cached, inputDataTypes) + val inputs = wrap(children.map(_.eval(input)), wrappers, cached, inputDataTypes) val ret = FunctionRegistry.invoke( method, function, @@ -215,6 +218,9 @@ private[hive] case class HiveGenericUDTF( private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray @transient + private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + + @transient private lazy val unwrapper = unwrapperFor(outputInspector) override def eval(input: InternalRow): TraversableOnce[InternalRow] = { @@ -222,7 +228,7 @@ private[hive] case class HiveGenericUDTF( val inputProjection = new InterpretedProjection(children) - function.process(wrap(inputProjection(input), inputInspectors, udtInput, inputDataTypes)) + function.process(wrap(inputProjection(input), wrappers, udtInput, inputDataTypes)) collector.collectRows() } @@ -297,6 +303,9 @@ private[hive] case class HiveUDAFFunction( private lazy val function = functionAndInspector._1 @transient + private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray + + @transient private lazy val returnInspector = functionAndInspector._2 @transient @@ -322,7 +331,7 @@ private[hive] case class HiveUDAFFunction( override def update(_buffer: MutableRow, input: InternalRow): Unit = { val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes)) + function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes)) } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { |