author    Brian Cho <bcho@fb.com>    2016-06-22 16:56:55 -0700
committer Herman van Hovell <hvanhovell@databricks.com>    2016-06-22 16:56:55 -0700
commit    4f869f88ee96fa57be79f972f218111b6feac67f (patch)
tree      dbefb2af67392cadf027280bac4781db02b3406f /sql/hive
parent    044971eca0ff3c2ce62afa665dbd3072d52cbbec (diff)
[SPARK-15956][SQL] When unwrapping ORC avoid pattern matching at runtime
## What changes were proposed in this pull request?

Extend the returning of unwrapper functions from primitive types to all types. This PR is based on https://github.com/apache/spark/pull/13676. It only fixes a bug with scala-2.10 compilation. All credit should go to dafrista.

## How was this patch tested?

The patch should pass all unit tests. Reading ORC files with non-primitive types with this change reduced the read time by ~15%.

Author: Brian Cho <bcho@fb.com>
Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #13854 from hvanhovell/SPARK-15956-scala210.
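The core idea of the change, as a minimal self-contained sketch (the `Inspector` hierarchy below is hypothetical and stands in for Hive's `ObjectInspector` types, not the actual Spark or Hive API): all pattern matching on the inspector happens once, when the unwrapper closure is built, and only the returned closure runs per row.

```scala
// Minimal sketch of the ahead-of-time unwrapper pattern. The Inspector
// types here are hypothetical stand-ins for Hive's ObjectInspectors.
sealed trait Inspector
case object IntInspector extends Inspector
case object StringInspector extends Inspector
final case class ListInspector(element: Inspector) extends Inspector

object UnwrapperSketch {
  // Built once per column: the match on the inspector happens here,
  // not inside the per-row hot loop.
  def unwrapperFor(oi: Inspector): Any => Any = oi match {
    case IntInspector =>
      data => if (data != null) data.asInstanceOf[Int] else null
    case StringInspector =>
      data => if (data != null) data.toString else null
    case ListInspector(elem) =>
      // Complex types unwrap recursively, but the element unwrapper is
      // still built only once and reused for every row and element.
      val elemUnwrapper = unwrapperFor(elem)
      data =>
        if (data != null) data.asInstanceOf[Seq[Any]].map(elemUnwrapper)
        else null
  }

  def main(args: Array[String]): Unit = {
    val unwrap = unwrapperFor(ListInspector(IntInspector)) // match once
    val rows: Seq[Any] = Seq(Seq(1, 2), null, Seq(3))
    rows.foreach(r => println(unwrap(r)))                  // closure per row
  }
}
```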
Diffstat (limited to 'sql/hive')
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala                  | 428
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala                     |   3
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala  |   6
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala                        |  21
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala              |   6
5 files changed, 314 insertions, 150 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 585befe378..bf5cc17a68 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -239,145 +239,6 @@ private[hive] trait HiveInspectors {
}
/**
- * Converts hive types to native catalyst types.
- * @param data the data in Hive type
- * @param oi the ObjectInspector associated with the Hive Type
- * @return convert the data into catalyst type
- * TODO return the function of (data => Any) instead for performance consideration
- *
- * Strictly follows the following order in unwrapping (constant OI has the higher priority):
- * Constant Null object inspector =>
- * return null
- * Constant object inspector =>
- * extract the value from constant object inspector
- * Check whether the `data` is null =>
- * return null if true
- * If object inspector prefers writable =>
- * extract writable from `data` and then get the catalyst type from the writable
- * Extract the java object directly from the object inspector
- *
- * NOTICE: the complex data type requires recursive unwrapping.
- */
- def unwrap(data: Any, oi: ObjectInspector): Any = oi match {
- case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null
- case poi: WritableConstantStringObjectInspector =>
- UTF8String.fromString(poi.getWritableConstantValue.toString)
- case poi: WritableConstantHiveVarcharObjectInspector =>
- UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue)
- case poi: WritableConstantHiveCharObjectInspector =>
- UTF8String.fromString(poi.getWritableConstantValue.getHiveChar.getValue)
- case poi: WritableConstantHiveDecimalObjectInspector =>
- HiveShim.toCatalystDecimal(
- PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector,
- poi.getWritableConstantValue.getHiveDecimal)
- case poi: WritableConstantTimestampObjectInspector =>
- val t = poi.getWritableConstantValue
- t.getSeconds * 1000000L + t.getNanos / 1000L
- case poi: WritableConstantIntObjectInspector =>
- poi.getWritableConstantValue.get()
- case poi: WritableConstantDoubleObjectInspector =>
- poi.getWritableConstantValue.get()
- case poi: WritableConstantBooleanObjectInspector =>
- poi.getWritableConstantValue.get()
- case poi: WritableConstantLongObjectInspector =>
- poi.getWritableConstantValue.get()
- case poi: WritableConstantFloatObjectInspector =>
- poi.getWritableConstantValue.get()
- case poi: WritableConstantShortObjectInspector =>
- poi.getWritableConstantValue.get()
- case poi: WritableConstantByteObjectInspector =>
- poi.getWritableConstantValue.get()
- case poi: WritableConstantBinaryObjectInspector =>
- val writable = poi.getWritableConstantValue
- val temp = new Array[Byte](writable.getLength)
- System.arraycopy(writable.getBytes, 0, temp, 0, temp.length)
- temp
- case poi: WritableConstantDateObjectInspector =>
- DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get())
- case mi: StandardConstantMapObjectInspector =>
- // take the value from the map inspector object, rather than the input data
- val keyValues = mi.getWritableConstantValue.asScala.toSeq
- val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray
- val values = keyValues.map(kv => unwrap(kv._2, mi.getMapValueObjectInspector)).toArray
- ArrayBasedMapData(keys, values)
- case li: StandardConstantListObjectInspector =>
- // take the value from the list inspector object, rather than the input data
- val values = li.getWritableConstantValue.asScala
- .map(unwrap(_, li.getListElementObjectInspector))
- .toArray
- new GenericArrayData(values)
- // if the value is null, we don't care about the object inspector type
- case _ if data == null => null
- case poi: VoidObjectInspector => null // always be null for void object inspector
- case pi: PrimitiveObjectInspector => pi match {
- // We think HiveVarchar/HiveChar is also a String
- case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() =>
- UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue)
- case hvoi: HiveVarcharObjectInspector =>
- UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue)
- case hvoi: HiveCharObjectInspector if hvoi.preferWritable() =>
- UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue)
- case hvoi: HiveCharObjectInspector =>
- UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue)
- case x: StringObjectInspector if x.preferWritable() =>
- // Text is in UTF-8 already. No need to convert again via fromString. Copy bytes
- val wObj = x.getPrimitiveWritableObject(data)
- val result = wObj.copyBytes()
- UTF8String.fromBytes(result, 0, result.length)
- case x: StringObjectInspector =>
- UTF8String.fromString(x.getPrimitiveJavaObject(data))
- case x: IntObjectInspector if x.preferWritable() => x.get(data)
- case x: BooleanObjectInspector if x.preferWritable() => x.get(data)
- case x: FloatObjectInspector if x.preferWritable() => x.get(data)
- case x: DoubleObjectInspector if x.preferWritable() => x.get(data)
- case x: LongObjectInspector if x.preferWritable() => x.get(data)
- case x: ShortObjectInspector if x.preferWritable() => x.get(data)
- case x: ByteObjectInspector if x.preferWritable() => x.get(data)
- case x: HiveDecimalObjectInspector => HiveShim.toCatalystDecimal(x, data)
- case x: BinaryObjectInspector if x.preferWritable() =>
- // BytesWritable.copyBytes() only available since Hadoop2
- // In order to keep backward-compatible, we have to copy the
- // bytes with old apis
- val bw = x.getPrimitiveWritableObject(data)
- val result = new Array[Byte](bw.getLength())
- System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength())
- result
- case x: DateObjectInspector if x.preferWritable() =>
- DateTimeUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get())
- case x: DateObjectInspector => DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data))
- case x: TimestampObjectInspector if x.preferWritable() =>
- val t = x.getPrimitiveWritableObject(data)
- t.getSeconds * 1000000L + t.getNanos / 1000L
- case ti: TimestampObjectInspector =>
- DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data))
- case _ => pi.getPrimitiveJavaObject(data)
- }
- case li: ListObjectInspector =>
- Option(li.getList(data))
- .map { l =>
- val values = l.asScala.map(unwrap(_, li.getListElementObjectInspector)).toArray
- new GenericArrayData(values)
- }
- .orNull
- case mi: MapObjectInspector =>
- val map = mi.getMap(data)
- if (map == null) {
- null
- } else {
- val keyValues = map.asScala.toSeq
- val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray
- val values = keyValues.map(kv => unwrap(kv._2, mi.getMapValueObjectInspector)).toArray
- ArrayBasedMapData(keys, values)
- }
- // currently, hive doesn't provide the ConstantStructObjectInspector
- case si: StructObjectInspector =>
- val allRefs = si.getAllStructFieldRefs
- InternalRow.fromSeq(allRefs.asScala.map(
- r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector)))
- }
-
-
- /**
* Wraps with Hive types based on object inspector.
* TODO: Consolidate all hive OI/data interface code.
*/
@@ -479,8 +340,292 @@ private[hive] trait HiveInspectors {
}
/**
- * Builds specific unwrappers ahead of time according to object inspector
+ * Builds unwrappers ahead of time according to object inspector
+ * types to avoid pattern matching and branching costs per row.
+ *
+ * Strictly follows the following order in unwrapping (constant OI has the higher priority):
+ * Constant Null object inspector =>
+ * return null
+ * Constant object inspector =>
+ * extract the value from constant object inspector
+ * If object inspector prefers writable =>
+ * extract writable from `data` and then get the catalyst type from the writable
+ * Extract the java object directly from the object inspector
+ *
+ * NOTICE: the complex data type requires recursive unwrapping.
+ *
+ * @param objectInspector the ObjectInspector used to create an unwrapper.
+ * @return A function that unwraps data objects.
+ * Use the overloaded HiveStructField version for in-place updating of a MutableRow.
+ */
+ def unwrapperFor(objectInspector: ObjectInspector): Any => Any =
+ objectInspector match {
+ case coi: ConstantObjectInspector if coi.getWritableConstantValue == null =>
+ _ => null
+ case poi: WritableConstantStringObjectInspector =>
+ val constant = UTF8String.fromString(poi.getWritableConstantValue.toString)
+ _ => constant
+ case poi: WritableConstantHiveVarcharObjectInspector =>
+ val constant = UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue)
+ _ => constant
+ case poi: WritableConstantHiveCharObjectInspector =>
+ val constant = UTF8String.fromString(poi.getWritableConstantValue.getHiveChar.getValue)
+ _ => constant
+ case poi: WritableConstantHiveDecimalObjectInspector =>
+ val constant = HiveShim.toCatalystDecimal(
+ PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector,
+ poi.getWritableConstantValue.getHiveDecimal)
+ _ => constant
+ case poi: WritableConstantTimestampObjectInspector =>
+ val t = poi.getWritableConstantValue
+ val constant = t.getSeconds * 1000000L + t.getNanos / 1000L
+ _ => constant
+ case poi: WritableConstantIntObjectInspector =>
+ val constant = poi.getWritableConstantValue.get()
+ _ => constant
+ case poi: WritableConstantDoubleObjectInspector =>
+ val constant = poi.getWritableConstantValue.get()
+ _ => constant
+ case poi: WritableConstantBooleanObjectInspector =>
+ val constant = poi.getWritableConstantValue.get()
+ _ => constant
+ case poi: WritableConstantLongObjectInspector =>
+ val constant = poi.getWritableConstantValue.get()
+ _ => constant
+ case poi: WritableConstantFloatObjectInspector =>
+ val constant = poi.getWritableConstantValue.get()
+ _ => constant
+ case poi: WritableConstantShortObjectInspector =>
+ val constant = poi.getWritableConstantValue.get()
+ _ => constant
+ case poi: WritableConstantByteObjectInspector =>
+ val constant = poi.getWritableConstantValue.get()
+ _ => constant
+ case poi: WritableConstantBinaryObjectInspector =>
+ val writable = poi.getWritableConstantValue
+ val constant = new Array[Byte](writable.getLength)
+ System.arraycopy(writable.getBytes, 0, constant, 0, constant.length)
+ _ => constant
+ case poi: WritableConstantDateObjectInspector =>
+ val constant = DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get())
+ _ => constant
+ case mi: StandardConstantMapObjectInspector =>
+ val keyUnwrapper = unwrapperFor(mi.getMapKeyObjectInspector)
+ val valueUnwrapper = unwrapperFor(mi.getMapValueObjectInspector)
+ val keyValues = mi.getWritableConstantValue.asScala.toSeq
+ val keys = keyValues.map(kv => keyUnwrapper(kv._1)).toArray
+ val values = keyValues.map(kv => valueUnwrapper(kv._2)).toArray
+ val constant = ArrayBasedMapData(keys, values)
+ _ => constant
+ case li: StandardConstantListObjectInspector =>
+ val unwrapper = unwrapperFor(li.getListElementObjectInspector)
+ val values = li.getWritableConstantValue.asScala
+ .map(unwrapper)
+ .toArray
+ val constant = new GenericArrayData(values)
+ _ => constant
+ case poi: VoidObjectInspector =>
+ _ => null // always be null for void object inspector
+ case pi: PrimitiveObjectInspector => pi match {
+ // We think HiveVarchar/HiveChar is also a String
+ case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() =>
+ data: Any => {
+ if (data != null) {
+ UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue)
+ } else {
+ null
+ }
+ }
+ case hvoi: HiveVarcharObjectInspector =>
+ data: Any => {
+ if (data != null) {
+ UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue)
+ } else {
+ null
+ }
+ }
+ case hvoi: HiveCharObjectInspector if hvoi.preferWritable() =>
+ data: Any => {
+ if (data != null) {
+ UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue)
+ } else {
+ null
+ }
+ }
+ case hvoi: HiveCharObjectInspector =>
+ data: Any => {
+ if (data != null) {
+ UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue)
+ } else {
+ null
+ }
+ }
+ case x: StringObjectInspector if x.preferWritable() =>
+ data: Any => {
+ if (data != null) {
+ // Text is in UTF-8 already. No need to convert again via fromString. Copy bytes
+ val wObj = x.getPrimitiveWritableObject(data)
+ val result = wObj.copyBytes()
+ UTF8String.fromBytes(result, 0, result.length)
+ } else {
+ null
+ }
+ }
+ case x: StringObjectInspector =>
+ data: Any => {
+ if (data != null) {
+ UTF8String.fromString(x.getPrimitiveJavaObject(data))
+ } else {
+ null
+ }
+ }
+ case x: IntObjectInspector if x.preferWritable() =>
+ data: Any => {
+ if (data != null) x.get(data) else null
+ }
+ case x: BooleanObjectInspector if x.preferWritable() =>
+ data: Any => {
+ if (data != null) x.get(data) else null
+ }
+ case x: FloatObjectInspector if x.preferWritable() =>
+ data: Any => {
+ if (data != null) x.get(data) else null
+ }
+ case x: DoubleObjectInspector if x.preferWritable() =>
+ data: Any => {
+ if (data != null) x.get(data) else null
+ }
+ case x: LongObjectInspector if x.preferWritable() =>
+ data: Any => {
+ if (data != null) x.get(data) else null
+ }
+ case x: ShortObjectInspector if x.preferWritable() =>
+ data: Any => {
+ if (data != null) x.get(data) else null
+ }
+ case x: ByteObjectInspector if x.preferWritable() =>
+ data: Any => {
+ if (data != null) x.get(data) else null
+ }
+ case x: HiveDecimalObjectInspector =>
+ data: Any => {
+ if (data != null) {
+ HiveShim.toCatalystDecimal(x, data)
+ } else {
+ null
+ }
+ }
+ case x: BinaryObjectInspector if x.preferWritable() =>
+ data: Any => {
+ if (data != null) {
+ // BytesWritable.copyBytes() only available since Hadoop2
+ // In order to keep backward-compatible, we have to copy the
+ // bytes with old apis
+ val bw = x.getPrimitiveWritableObject(data)
+ val result = new Array[Byte](bw.getLength())
+ System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength())
+ result
+ } else {
+ null
+ }
+ }
+ case x: DateObjectInspector if x.preferWritable() =>
+ data: Any => {
+ if (data != null) {
+ DateTimeUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get())
+ } else {
+ null
+ }
+ }
+ case x: DateObjectInspector =>
+ data: Any => {
+ if (data != null) {
+ DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data))
+ } else {
+ null
+ }
+ }
+ case x: TimestampObjectInspector if x.preferWritable() =>
+ data: Any => {
+ if (data != null) {
+ val t = x.getPrimitiveWritableObject(data)
+ t.getSeconds * 1000000L + t.getNanos / 1000L
+ } else {
+ null
+ }
+ }
+ case ti: TimestampObjectInspector =>
+ data: Any => {
+ if (data != null) {
+ DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data))
+ } else {
+ null
+ }
+ }
+ case _ =>
+ data: Any => {
+ if (data != null) {
+ pi.getPrimitiveJavaObject(data)
+ } else {
+ null
+ }
+ }
+ }
+ case li: ListObjectInspector =>
+ val unwrapper = unwrapperFor(li.getListElementObjectInspector)
+ data: Any => {
+ if (data != null) {
+ Option(li.getList(data))
+ .map { l =>
+ val values = l.asScala.map(unwrapper).toArray
+ new GenericArrayData(values)
+ }
+ .orNull
+ } else {
+ null
+ }
+ }
+ case mi: MapObjectInspector =>
+ val keyUnwrapper = unwrapperFor(mi.getMapKeyObjectInspector)
+ val valueUnwrapper = unwrapperFor(mi.getMapValueObjectInspector)
+ data: Any => {
+ if (data != null) {
+ val map = mi.getMap(data)
+ if (map == null) {
+ null
+ } else {
+ val keyValues = map.asScala.toSeq
+ val keys = keyValues.map(kv => keyUnwrapper(kv._1)).toArray
+ val values = keyValues.map(kv => valueUnwrapper(kv._2)).toArray
+ ArrayBasedMapData(keys, values)
+ }
+ } else {
+ null
+ }
+ }
+ // currently, hive doesn't provide the ConstantStructObjectInspector
+ case si: StructObjectInspector =>
+ val fields = si.getAllStructFieldRefs.asScala
+ val unwrappers = fields.map { field =>
+ val unwrapper = unwrapperFor(field.getFieldObjectInspector)
+ data: Any => unwrapper(si.getStructFieldData(data, field))
+ }
+ data: Any => {
+ if (data != null) {
+ InternalRow.fromSeq(unwrappers.map(_(data)))
+ } else {
+ null
+ }
+ }
+ }
+
+ /**
+ * Builds unwrappers ahead of time according to object inspector
* types to avoid pattern matching and branching costs per row.
+ *
+ * @param field The HiveStructField to create an unwrapper for.
+ * @return A function that performs in-place updating of a MutableRow.
+ * Use the overloaded ObjectInspector version for assignments.
*/
def unwrapperFor(field: HiveStructField): (Any, MutableRow, Int) => Unit =
field.getFieldObjectInspector match {
@@ -499,7 +644,8 @@ private[hive] trait HiveInspectors {
case oi: DoubleObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value))
case oi =>
- (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi)
+ val unwrapper = unwrapperFor(oi)
+ (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapper(value)
}
/**
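The `HiveStructField` overload above returns `(Any, MutableRow, Int) => Unit` so that primitive columns can be written through typed setters without boxing, while every other type falls back to a generic unwrapper plus a boxed update. A minimal sketch of that shape, assuming a toy row class in place of Catalyst's `MutableRow` (everything below is illustrative, not the Spark API); the `TableReader` and `ScriptTransformation` diffs that follow apply the same precomputed setters in their row-filling loops:

```scala
// Toy mutable row standing in for Catalyst's MutableRow (assumption).
final class ToyRow(size: Int) {
  private val values = new Array[Any](size)
  def setInt(ordinal: Int, v: Int): Unit = values(ordinal) = v
  def update(ordinal: Int, v: Any): Unit = values(ordinal) = v
  override def toString: String = values.mkString("[", ", ", "]")
}

object FieldUnwrapperSketch {
  // Built once per field: primitives get a specialized setter; anything
  // else falls back to a generic unwrapper plus a boxed update.
  def fieldUnwrapperFor(isInt: Boolean): (Any, ToyRow, Int) => Unit =
    if (isInt) {
      (v, row, i) => row.setInt(i, v.asInstanceOf[Int])
    } else {
      val unwrapper: Any => Any = identity // placeholder for unwrapperFor(oi)
      (v, row, i) => row.update(i, unwrapper(v))
    }

  def main(args: Array[String]): Unit = {
    val row = new ToyRow(2)
    val setters = Seq(fieldUnwrapperFor(true), fieldUnwrapperFor(false))
    val data: Seq[Any] = Seq(42, "hello")
    var i = 0
    while (i < data.size) { setters(i)(data(i), row, i); i += 1 } // hot loop
    println(row) // [42, hello]
  }
}
```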
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index d044811052..e49a235643 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -401,7 +401,8 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging {
(value: Any, row: MutableRow, ordinal: Int) =>
row.update(ordinal, oi.getPrimitiveJavaObject(value))
case oi =>
- (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi)
+ val unwrapper = unwrapperFor(oi)
+ (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapper(value)
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 9e25e1d40c..84990d3697 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -127,6 +127,9 @@ case class ScriptTransformation(
}
val mutableRow = new SpecificMutableRow(output.map(_.dataType))
+ @transient
+ lazy val unwrappers = outputSoi.getAllStructFieldRefs.asScala.map(unwrapperFor)
+
private def checkFailureAndPropagate(cause: Throwable = null): Unit = {
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
@@ -215,13 +218,12 @@ case class ScriptTransformation(
val raw = outputSerde.deserialize(scriptOutputWritable)
scriptOutputWritable = null
val dataList = outputSoi.getStructFieldsDataAsList(raw)
- val fieldList = outputSoi.getAllStructFieldRefs()
var i = 0
while (i < dataList.size()) {
if (dataList.get(i) == null) {
mutableRow.setNullAt(i)
} else {
- mutableRow(i) = unwrap(dataList.get(i), fieldList.get(i).getFieldObjectInspector)
+ unwrappers(i)(dataList.get(i), mutableRow, i)
}
i += 1
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index c53675694f..9347aeb8e0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -71,8 +71,8 @@ private[hive] case class HiveSimpleUDF(
override lazy val dataType = javaClassToDataType(method.getReturnType)
@transient
- lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector(
- method.getGenericReturnType(), ObjectInspectorOptions.JAVA)
+ lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
+ method.getGenericReturnType(), ObjectInspectorOptions.JAVA))
@transient
private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
@@ -87,7 +87,7 @@ private[hive] case class HiveSimpleUDF(
method,
function,
conversionHelper.convertIfNecessary(inputs : _*): _*)
- unwrap(ret, returnInspector)
+ unwrapper(ret)
}
override def toString: String = {
@@ -134,6 +134,9 @@ private[hive] case class HiveGenericUDF(
}
@transient
+ private lazy val unwrapper = unwrapperFor(returnInspector)
+
+ @transient
private lazy val isUDFDeterministic = {
val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
udfType != null && udfType.deterministic()
@@ -156,7 +159,7 @@ private[hive] case class HiveGenericUDF(
.set(() => children(idx).eval(input))
i += 1
}
- unwrap(function.evaluate(deferredObjects), returnInspector)
+ unwrapper(function.evaluate(deferredObjects))
}
override def prettyName: String = name
@@ -210,6 +213,9 @@ private[hive] case class HiveGenericUDTF(
@transient
private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+ @transient
+ private lazy val unwrapper = unwrapperFor(outputInspector)
+
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
outputInspector // Make sure initialized.
@@ -226,7 +232,7 @@ private[hive] case class HiveGenericUDTF(
// We need to clone the input here because implementations of
// GenericUDTF reuse the same object. Luckily they are always an array, so
// it is easy to clone.
- collected += unwrap(input, outputInspector).asInstanceOf[InternalRow]
+ collected += unwrapper(input).asInstanceOf[InternalRow]
}
def collectRows(): Seq[InternalRow] = {
@@ -293,9 +299,12 @@ private[hive] case class HiveUDAFFunction(
private lazy val returnInspector = functionAndInspector._2
@transient
+ private lazy val unwrapper = unwrapperFor(returnInspector)
+
+ @transient
private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _
- override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector)
+ override def eval(input: InternalRow): Any = unwrapper(function.evaluate(buffer))
@transient
private lazy val inputProjection = new InterpretedProjection(children)
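In the UDF wrappers the unwrapper is held in a `@transient lazy val`, so it is excluded from task serialization and rebuilt at most once per deserialized instance rather than on every `eval` call. A minimal sketch of that caching pattern (the `Udf` class below is hypothetical):

```scala
// Hypothetical UDF wrapper showing the @transient lazy val caching pattern.
final class Udf(name: String) extends Serializable {
  // Not serialized with the task; rebuilt lazily on first use.
  @transient private lazy val unwrapper: Any => Any = {
    println(s"building unwrapper for $name") // runs once per instance
    data => if (data != null) data.toString else null
  }
  def eval(raw: Any): Any = unwrapper(raw) // per-call: just apply the closure
}

object UdfCachingSketch {
  def main(args: Array[String]): Unit = {
    val udf = new Udf("upper")
    println(udf.eval(1)) // builds the unwrapper, then applies it
    println(udf.eval(2)) // reuses the cached closure
  }
}
```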
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 3b867bbfa1..bc51bcb07e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -35,6 +35,12 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
+
+ def unwrap(data: Any, oi: ObjectInspector): Any = {
+ val unwrapper = unwrapperFor(oi)
+ unwrapper(data)
+ }
+
test("Test wrap SettableStructObjectInspector") {
val udaf = new UDAFPercentile.PercentileLongEvaluator()
udaf.init()