author    Herman van Hovell <hvanhovell@databricks.com>  2016-06-22 11:36:32 -0700
committer Herman van Hovell <hvanhovell@databricks.com>  2016-06-22 11:36:32 -0700
commit    472d611a70da02d95e36da754435a3ac562f8b24 (patch)
tree      f7add948b2f8c43fa958a5f169405a2f746015ad /sql
parent    c2cebdb7ddff3d041d548fe1cd8de4efb31b294f (diff)
[SPARK-15956][SQL] Revert "[] When unwrapping ORC avoid pattern matching…
This reverts commit 0a9c02759515c41de37db6381750bc3a316c860c. It breaks the 2.10 build; I'll fix this in a different PR.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #13853 from hvanhovell/SPARK-15956-revert.
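For context, the change being reverted had replaced the per-row `unwrap(data, oi)` pattern match with unwrapper closures built once per column. A toy sketch of the two styles, using a hypothetical `Inspector` type in place of Hive's `ObjectInspector` hierarchy (illustration only, not the patch itself):

  sealed trait Inspector
  case object IntInspector extends Inspector
  case object StringInspector extends Inspector

  // Style restored by this revert: pattern match on the inspector for every value.
  def unwrap(data: Any, oi: Inspector): Any = oi match {
    case IntInspector    => data.asInstanceOf[Int]
    case StringInspector => data.asInstanceOf[String]
  }

  // Style being reverted: match once per column, then apply the returned
  // closure to each row, avoiding the repeated branching.
  def unwrapperFor(oi: Inspector): Any => Any = oi match {
    case IntInspector    => (data: Any) => data.asInstanceOf[Int]
    case StringInspector => (data: Any) => data.asInstanceOf[String]
  }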
Diffstat (limited to 'sql')
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala                 | 428
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala                    |   3
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala |   6
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala                       |  21
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala             |   6
5 files changed, 150 insertions(+), 314 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 1aadc8b31b..585befe378 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -239,6 +239,145 @@ private[hive] trait HiveInspectors {
}
/**
+ * Converts hive types to native catalyst types.
+ * @param data the data in Hive type
+ * @param oi the ObjectInspector associated with the Hive Type
+ * @return the data converted into a Catalyst type
+ * TODO: return a function (data => Any) instead, for performance
+ *
+ * Strictly follows the following order in unwrapping (constant OI has the higher priority):
+ * Constant Null object inspector =>
+ * return null
+ * Constant object inspector =>
+ * extract the value from constant object inspector
+ * Check whether the `data` is null =>
+ * return null if true
+ * If object inspector prefers writable =>
+ * extract writable from `data` and then get the catalyst type from the writable
+ * Extract the java object directly from the object inspector
+ *
+ * NOTICE: complex data types require recursive unwrapping.
+ */
+ def unwrap(data: Any, oi: ObjectInspector): Any = oi match {
+ case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null
+ case poi: WritableConstantStringObjectInspector =>
+ UTF8String.fromString(poi.getWritableConstantValue.toString)
+ case poi: WritableConstantHiveVarcharObjectInspector =>
+ UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue)
+ case poi: WritableConstantHiveCharObjectInspector =>
+ UTF8String.fromString(poi.getWritableConstantValue.getHiveChar.getValue)
+ case poi: WritableConstantHiveDecimalObjectInspector =>
+ HiveShim.toCatalystDecimal(
+ PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector,
+ poi.getWritableConstantValue.getHiveDecimal)
+ case poi: WritableConstantTimestampObjectInspector =>
+ val t = poi.getWritableConstantValue
+ t.getSeconds * 1000000L + t.getNanos / 1000L
+ case poi: WritableConstantIntObjectInspector =>
+ poi.getWritableConstantValue.get()
+ case poi: WritableConstantDoubleObjectInspector =>
+ poi.getWritableConstantValue.get()
+ case poi: WritableConstantBooleanObjectInspector =>
+ poi.getWritableConstantValue.get()
+ case poi: WritableConstantLongObjectInspector =>
+ poi.getWritableConstantValue.get()
+ case poi: WritableConstantFloatObjectInspector =>
+ poi.getWritableConstantValue.get()
+ case poi: WritableConstantShortObjectInspector =>
+ poi.getWritableConstantValue.get()
+ case poi: WritableConstantByteObjectInspector =>
+ poi.getWritableConstantValue.get()
+ case poi: WritableConstantBinaryObjectInspector =>
+ val writable = poi.getWritableConstantValue
+ val temp = new Array[Byte](writable.getLength)
+ System.arraycopy(writable.getBytes, 0, temp, 0, temp.length)
+ temp
+ case poi: WritableConstantDateObjectInspector =>
+ DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get())
+ case mi: StandardConstantMapObjectInspector =>
+ // take the value from the map inspector object, rather than the input data
+ val keyValues = mi.getWritableConstantValue.asScala.toSeq
+ val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray
+ val values = keyValues.map(kv => unwrap(kv._2, mi.getMapValueObjectInspector)).toArray
+ ArrayBasedMapData(keys, values)
+ case li: StandardConstantListObjectInspector =>
+ // take the value from the list inspector object, rather than the input data
+ val values = li.getWritableConstantValue.asScala
+ .map(unwrap(_, li.getListElementObjectInspector))
+ .toArray
+ new GenericArrayData(values)
+ // if the value is null, we don't care about the object inspector type
+ case _ if data == null => null
+ case poi: VoidObjectInspector => null // always be null for void object inspector
+ case pi: PrimitiveObjectInspector => pi match {
+ // We think HiveVarchar/HiveChar is also a String
+ case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() =>
+ UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue)
+ case hvoi: HiveVarcharObjectInspector =>
+ UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue)
+ case hvoi: HiveCharObjectInspector if hvoi.preferWritable() =>
+ UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue)
+ case hvoi: HiveCharObjectInspector =>
+ UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue)
+ case x: StringObjectInspector if x.preferWritable() =>
+ // Text is in UTF-8 already. No need to convert again via fromString. Copy bytes
+ val wObj = x.getPrimitiveWritableObject(data)
+ val result = wObj.copyBytes()
+ UTF8String.fromBytes(result, 0, result.length)
+ case x: StringObjectInspector =>
+ UTF8String.fromString(x.getPrimitiveJavaObject(data))
+ case x: IntObjectInspector if x.preferWritable() => x.get(data)
+ case x: BooleanObjectInspector if x.preferWritable() => x.get(data)
+ case x: FloatObjectInspector if x.preferWritable() => x.get(data)
+ case x: DoubleObjectInspector if x.preferWritable() => x.get(data)
+ case x: LongObjectInspector if x.preferWritable() => x.get(data)
+ case x: ShortObjectInspector if x.preferWritable() => x.get(data)
+ case x: ByteObjectInspector if x.preferWritable() => x.get(data)
+ case x: HiveDecimalObjectInspector => HiveShim.toCatalystDecimal(x, data)
+ case x: BinaryObjectInspector if x.preferWritable() =>
+ // BytesWritable.copyBytes() only available since Hadoop2
+ // In order to keep backward-compatible, we have to copy the
+ // bytes with old apis
+ val bw = x.getPrimitiveWritableObject(data)
+ val result = new Array[Byte](bw.getLength())
+ System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength())
+ result
+ case x: DateObjectInspector if x.preferWritable() =>
+ DateTimeUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get())
+ case x: DateObjectInspector => DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data))
+ case x: TimestampObjectInspector if x.preferWritable() =>
+ val t = x.getPrimitiveWritableObject(data)
+ t.getSeconds * 1000000L + t.getNanos / 1000L
+ case ti: TimestampObjectInspector =>
+ DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data))
+ case _ => pi.getPrimitiveJavaObject(data)
+ }
+ case li: ListObjectInspector =>
+ Option(li.getList(data))
+ .map { l =>
+ val values = l.asScala.map(unwrap(_, li.getListElementObjectInspector)).toArray
+ new GenericArrayData(values)
+ }
+ .orNull
+ case mi: MapObjectInspector =>
+ val map = mi.getMap(data)
+ if (map == null) {
+ null
+ } else {
+ val keyValues = map.asScala.toSeq
+ val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray
+ val values = keyValues.map(kv => unwrap(kv._2, mi.getMapValueObjectInspector)).toArray
+ ArrayBasedMapData(keys, values)
+ }
+ // currently, hive doesn't provide the ConstantStructObjectInspector
+ case si: StructObjectInspector =>
+ val allRefs = si.getAllStructFieldRefs
+ InternalRow.fromSeq(allRefs.asScala.map(
+ r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector)))
+ }
+
+
+ /**
* Wraps with Hive types based on object inspector.
* TODO: Consolidate all hive OI/data interface code.
*/
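As a rough usage illustration of the restored entry point (a sketch assuming a class that mixes in HiveInspectors; the inspector factory field is from Hive's public serde2 API):

  import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
  import org.apache.hadoop.io.IntWritable

  // A writable int inspector prefers writables, so unwrap takes the
  // `IntObjectInspector if preferWritable` arm and returns a plain Int.
  val oi = PrimitiveObjectInspectorFactory.writableIntObjectInspector
  val value = unwrap(new IntWritable(42), oi)  // 42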
@@ -340,292 +479,8 @@ private[hive] trait HiveInspectors {
}
/**
- * Builds unwrappers ahead of time according to object inspector
+ * Builds specific unwrappers ahead of time according to object inspector
* types to avoid pattern matching and branching costs per row.
- *
- * Strictly follows the following order in unwrapping (constant OI has the higher priority):
- * Constant Null object inspector =>
- * return null
- * Constant object inspector =>
- * extract the value from constant object inspector
- * If object inspector prefers writable =>
- * extract writable from `data` and then get the catalyst type from the writable
- * Extract the java object directly from the object inspector
- *
- * NOTICE: the complex data type requires recursive unwrapping.
- *
- * @param objectInspector the ObjectInspector used to create an unwrapper.
- * @return A function that unwraps data objects.
- * Use the overloaded HiveStructField version for in-place updating of a MutableRow.
- */
- def unwrapperFor(objectInspector: ObjectInspector): Any => Any =
- objectInspector match {
- case coi: ConstantObjectInspector if coi.getWritableConstantValue == null =>
- _ => null
- case poi: WritableConstantStringObjectInspector =>
- val constant = UTF8String.fromString(poi.getWritableConstantValue.toString)
- _ => constant
- case poi: WritableConstantHiveVarcharObjectInspector =>
- val constant = UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue)
- _ => constant
- case poi: WritableConstantHiveCharObjectInspector =>
- val constant = UTF8String.fromString(poi.getWritableConstantValue.getHiveChar.getValue)
- _ => constant
- case poi: WritableConstantHiveDecimalObjectInspector =>
- val constant = HiveShim.toCatalystDecimal(
- PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector,
- poi.getWritableConstantValue.getHiveDecimal)
- _ => constant
- case poi: WritableConstantTimestampObjectInspector =>
- val t = poi.getWritableConstantValue
- val constant = t.getSeconds * 1000000L + t.getNanos / 1000L
- _ => constant
- case poi: WritableConstantIntObjectInspector =>
- val constant = poi.getWritableConstantValue.get()
- _ => constant
- case poi: WritableConstantDoubleObjectInspector =>
- val constant = poi.getWritableConstantValue.get()
- _ => constant
- case poi: WritableConstantBooleanObjectInspector =>
- val constant = poi.getWritableConstantValue.get()
- _ => constant
- case poi: WritableConstantLongObjectInspector =>
- val constant = poi.getWritableConstantValue.get()
- _ => constant
- case poi: WritableConstantFloatObjectInspector =>
- val constant = poi.getWritableConstantValue.get()
- _ => constant
- case poi: WritableConstantShortObjectInspector =>
- val constant = poi.getWritableConstantValue.get()
- _ => constant
- case poi: WritableConstantByteObjectInspector =>
- val constant = poi.getWritableConstantValue.get()
- _ => constant
- case poi: WritableConstantBinaryObjectInspector =>
- val writable = poi.getWritableConstantValue
- val constant = new Array[Byte](writable.getLength)
- System.arraycopy(writable.getBytes, 0, constant, 0, constant.length)
- _ => constant
- case poi: WritableConstantDateObjectInspector =>
- val constant = DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get())
- _ => constant
- case mi: StandardConstantMapObjectInspector =>
- val keyUnwrapper = unwrapperFor(mi.getMapKeyObjectInspector)
- val valueUnwrapper = unwrapperFor(mi.getMapValueObjectInspector)
- val keyValues = mi.getWritableConstantValue.asScala.toSeq
- val keys = keyValues.map(kv => keyUnwrapper(kv._1)).toArray
- val values = keyValues.map(kv => valueUnwrapper(kv._2)).toArray
- val constant = ArrayBasedMapData(keys, values)
- _ => constant
- case li: StandardConstantListObjectInspector =>
- val unwrapper = unwrapperFor(li.getListElementObjectInspector)
- val values = li.getWritableConstantValue.asScala
- .map(unwrapper)
- .toArray
- val constant = new GenericArrayData(values)
- _ => constant
- case poi: VoidObjectInspector =>
- _ => null // always be null for void object inspector
- case pi: PrimitiveObjectInspector => pi match {
- // We think HiveVarchar/HiveChar is also a String
- case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() =>
- data: Any => {
- if (data != null) {
- UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue)
- } else {
- null
- }
- }
- case hvoi: HiveVarcharObjectInspector =>
- data: Any => {
- if (data != null) {
- UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue)
- } else {
- null
- }
- }
- case hvoi: HiveCharObjectInspector if hvoi.preferWritable() =>
- data: Any => {
- if (data != null) {
- UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue)
- } else {
- null
- }
- }
- case hvoi: HiveCharObjectInspector =>
- data: Any => {
- if (data != null) {
- UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue)
- } else {
- null
- }
- }
- case x: StringObjectInspector if x.preferWritable() =>
- data: Any => {
- if (data != null) {
- // Text is in UTF-8 already. No need to convert again via fromString. Copy bytes
- val wObj = x.getPrimitiveWritableObject(data)
- val result = wObj.copyBytes()
- UTF8String.fromBytes(result, 0, result.length)
- } else {
- null
- }
- }
- case x: StringObjectInspector =>
- data: Any => {
- if (data != null) {
- UTF8String.fromString(x.getPrimitiveJavaObject(data))
- } else {
- null
- }
- }
- case x: IntObjectInspector if x.preferWritable() =>
- data: Any => {
- if (data != null) x.get(data) else null
- }
- case x: BooleanObjectInspector if x.preferWritable() =>
- data: Any => {
- if (data != null) x.get(data) else null
- }
- case x: FloatObjectInspector if x.preferWritable() =>
- data: Any => {
- if (data != null) x.get(data) else null
- }
- case x: DoubleObjectInspector if x.preferWritable() =>
- data: Any => {
- if (data != null) x.get(data) else null
- }
- case x: LongObjectInspector if x.preferWritable() =>
- data: Any => {
- if (data != null) x.get(data) else null
- }
- case x: ShortObjectInspector if x.preferWritable() =>
- data: Any => {
- if (data != null) x.get(data) else null
- }
- case x: ByteObjectInspector if x.preferWritable() =>
- data: Any => {
- if (data != null) x.get(data) else null
- }
- case x: HiveDecimalObjectInspector =>
- data: Any => {
- if (data != null) {
- HiveShim.toCatalystDecimal(x, data)
- } else {
- null
- }
- }
- case x: BinaryObjectInspector if x.preferWritable() =>
- data: Any => {
- if (data != null) {
- // BytesWritable.copyBytes() only available since Hadoop2
- // In order to keep backward-compatible, we have to copy the
- // bytes with old apis
- val bw = x.getPrimitiveWritableObject(data)
- val result = new Array[Byte](bw.getLength())
- System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength())
- result
- } else {
- null
- }
- }
- case x: DateObjectInspector if x.preferWritable() =>
- data: Any => {
- if (data != null) {
- DateTimeUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get())
- } else {
- null
- }
- }
- case x: DateObjectInspector =>
- data: Any => {
- if (data != null) {
- DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data))
- } else {
- null
- }
- }
- case x: TimestampObjectInspector if x.preferWritable() =>
- data: Any => {
- if (data != null) {
- val t = x.getPrimitiveWritableObject(data)
- t.getSeconds * 1000000L + t.getNanos / 1000L
- } else {
- null
- }
- }
- case ti: TimestampObjectInspector =>
- data: Any => {
- if (data != null) {
- DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data))
- } else {
- null
- }
- }
- case _ =>
- data: Any => {
- if (data != null) {
- pi.getPrimitiveJavaObject(data)
- } else {
- null
- }
- }
- }
- case li: ListObjectInspector =>
- val unwrapper = unwrapperFor(li.getListElementObjectInspector)
- data: Any => {
- if (data != null) {
- Option(li.getList(data))
- .map { l =>
- val values = l.asScala.map(unwrapper).toArray
- new GenericArrayData(values)
- }
- .orNull
- } else {
- null
- }
- }
- case mi: MapObjectInspector =>
- val keyUnwrapper = unwrapperFor(mi.getMapKeyObjectInspector)
- val valueUnwrapper = unwrapperFor(mi.getMapValueObjectInspector)
- data: Any => {
- if (data != null) {
- val map = mi.getMap(data)
- if (map == null) {
- null
- } else {
- val keyValues = map.asScala.toSeq
- val keys = keyValues.map(kv => keyUnwrapper(kv._1)).toArray
- val values = keyValues.map(kv => valueUnwrapper(kv._2)).toArray
- ArrayBasedMapData(keys, values)
- }
- } else {
- null
- }
- }
- // currently, hive doesn't provide the ConstantStructObjectInspector
- case si: StructObjectInspector =>
- val fields = si.getAllStructFieldRefs.asScala
- val fieldsToUnwrap = fields.zip(
- fields.map(_.getFieldObjectInspector).map(unwrapperFor))
- data: Any => {
- if (data != null) {
- InternalRow.fromSeq(fieldsToUnwrap.map { case (field, unwrapper) =>
- unwrapper(si.getStructFieldData(data, field))
- })
- } else {
- null
- }
- }
- }
-
- /**
- * Builds unwrappers ahead of time according to object inspector
- * types to avoid pattern matching and branching costs per row.
- *
- * @param field The HiveStructField to create an unwrapper for.
- * @return A function that performs in-place updating of a MutableRow.
- * Use the overloaded ObjectInspector version for assignments.
*/
def unwrapperFor(field: HiveStructField): (Any, MutableRow, Int) => Unit =
field.getFieldObjectInspector match {
@@ -644,8 +499,7 @@ private[hive] trait HiveInspectors {
case oi: DoubleObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value))
case oi =>
- val unwrapper = unwrapperFor(oi)
- (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapper(value)
+ (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi)
}
/**
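The fallback arm above now routes non-primitive fields through `unwrap` directly instead of through a prebuilt closure; the same substitution appears in TableReader below. A minimal sketch of the in-place setter shape, with Array[Any] standing in for MutableRow and `unwrap` passed in explicitly (hypothetical helper, not Spark code):

  // Generic fallback: convert the Hive value with unwrap, then write it into
  // the row slot; primitive inspectors get specialized setters instead.
  def genericSetter[OI](oi: OI, unwrap: (Any, OI) => Any): (Any, Array[Any], Int) => Unit =
    (value, row, ordinal) => row(ordinal) = unwrap(value, oi)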
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index e49a235643..d044811052 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -401,8 +401,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging {
(value: Any, row: MutableRow, ordinal: Int) =>
row.update(ordinal, oi.getPrimitiveJavaObject(value))
case oi =>
- val unwrapper = unwrapperFor(oi)
- (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapper(value)
+ (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi)
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 84990d3697..9e25e1d40c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -127,9 +127,6 @@ case class ScriptTransformation(
}
val mutableRow = new SpecificMutableRow(output.map(_.dataType))
- @transient
- lazy val unwrappers = outputSoi.getAllStructFieldRefs.asScala.map(unwrapperFor)
-
private def checkFailureAndPropagate(cause: Throwable = null): Unit = {
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
@@ -218,12 +215,13 @@ case class ScriptTransformation(
val raw = outputSerde.deserialize(scriptOutputWritable)
scriptOutputWritable = null
val dataList = outputSoi.getStructFieldsDataAsList(raw)
+ val fieldList = outputSoi.getAllStructFieldRefs()
var i = 0
while (i < dataList.size()) {
if (dataList.get(i) == null) {
mutableRow.setNullAt(i)
} else {
- unwrappers(i)(dataList.get(i), mutableRow, i)
+ mutableRow(i) = unwrap(dataList.get(i), fieldList.get(i).getFieldObjectInspector)
}
i += 1
}
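With the prebuilt unwrappers gone, ScriptTransformation looks up each field's inspector inside the row loop. A simplified sketch of that loop shape (toy types; `unwrap` and the inspector type OI are assumed, as in the sketches above):

  // Null output fields are recorded as null; everything else is unwrapped
  // through the matching field inspector, one field at a time.
  def fillRow[OI](data: IndexedSeq[Any], inspectors: IndexedSeq[OI],
                  row: Array[Any], unwrap: (Any, OI) => Any): Unit = {
    var i = 0
    while (i < data.length) {
      row(i) = if (data(i) == null) null else unwrap(data(i), inspectors(i))
      i += 1
    }
  }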
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 9347aeb8e0..c53675694f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -71,8 +71,8 @@ private[hive] case class HiveSimpleUDF(
override lazy val dataType = javaClassToDataType(method.getReturnType)
@transient
- lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
- method.getGenericReturnType(), ObjectInspectorOptions.JAVA))
+ lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector(
+ method.getGenericReturnType(), ObjectInspectorOptions.JAVA)
@transient
private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
@@ -87,7 +87,7 @@ private[hive] case class HiveSimpleUDF(
method,
function,
conversionHelper.convertIfNecessary(inputs : _*): _*)
- unwrapper(ret)
+ unwrap(ret, returnInspector)
}
override def toString: String = {
@@ -134,9 +134,6 @@ private[hive] case class HiveGenericUDF(
}
@transient
- private lazy val unwrapper = unwrapperFor(returnInspector)
-
- @transient
private lazy val isUDFDeterministic = {
val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
udfType != null && udfType.deterministic()
@@ -159,7 +156,7 @@ private[hive] case class HiveGenericUDF(
.set(() => children(idx).eval(input))
i += 1
}
- unwrapper(function.evaluate(deferredObjects))
+ unwrap(function.evaluate(deferredObjects), returnInspector)
}
override def prettyName: String = name
@@ -213,9 +210,6 @@ private[hive] case class HiveGenericUDTF(
@transient
private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
- @transient
- private lazy val unwrapper = unwrapperFor(outputInspector)
-
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
outputInspector // Make sure initialized.
@@ -232,7 +226,7 @@ private[hive] case class HiveGenericUDTF(
// We need to clone the input here because implementations of
// GenericUDTF reuse the same object. Luckily they are always an array, so
// it is easy to clone.
- collected += unwrapper(input).asInstanceOf[InternalRow]
+ collected += unwrap(input, outputInspector).asInstanceOf[InternalRow]
}
def collectRows(): Seq[InternalRow] = {
@@ -299,12 +293,9 @@ private[hive] case class HiveUDAFFunction(
private lazy val returnInspector = functionAndInspector._2
@transient
- private lazy val unwrapper = unwrapperFor(returnInspector)
-
- @transient
private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _
- override def eval(input: InternalRow): Any = unwrapper(function.evaluate(buffer))
+ override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector)
@transient
private lazy val inputProjection = new InterpretedProjection(children)
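All four UDF wrappers in this file end up with the same shape after the revert: evaluate the Hive function, then convert the raw result through its return inspector. Distilled into a hypothetical helper (illustration only, not the actual Spark code):

  // Evaluate, then unwrap: the return inspector tells unwrap how to turn
  // the Hive-typed result into a Catalyst value.
  def evalAndUnwrap[OI](evaluate: () => Any, returnInspector: OI,
                        unwrap: (Any, OI) => Any): Any =
    unwrap(evaluate(), returnInspector)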
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index bc51bcb07e..3b867bbfa1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -35,12 +35,6 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
-
- def unwrap(data: Any, oi: ObjectInspector): Any = {
- val unwrapper = unwrapperFor(oi)
- unwrapper(data)
- }
-
test("Test wrap SettableStructObjectInspector") {
val udaf = new UDAFPercentile.PercentileLongEvaluator()
udaf.init()