author    Sital Kedia <skedia@fb.com>    2016-10-02 15:47:36 -0700
committer Reynold Xin <rxin@databricks.com>    2016-10-02 15:47:36 -0700
commit    f8d7fade4b9a78ae87b6012e3d6f71eef3032b22 (patch)
tree      02d94ff639ac9a24b8ac6f1b2a5b082b068d09de
parent    b88cb63da39786c07cb4bfa70afed32ec5eb3286 (diff)
[SPARK-17509][SQL] When wrapping catalyst datatype to Hive data type avoid…
## What changes were proposed in this pull request?

When wrapping catalyst datatypes to Hive data types, the wrap function was doing expensive pattern matching that consumed around 11% of CPU time. Avoid the per-value pattern matching by resolving the wrapper function only once and reusing it, as sketched after the diffstat below.

## How was this patch tested?

Tested by running the job on a cluster; observed around 8% CPU improvement.

Author: Sital Kedia <skedia@fb.com>

Closes #15064 from sitalkedia/skedia/hive_wrapper.
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 307
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala       |  15
2 files changed, 145 insertions(+), 177 deletions(-)
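
The whole win is in the shape of `wrapperFor` in the diff that follows: the pattern match on the ObjectInspector now runs once, when the wrapper is built, instead of once per wrapped value. A minimal sketch of that pattern, with a hypothetical `Inspector` hierarchy standing in for Hive's ObjectInspectors (illustrative, not the actual Spark code):

```scala
// Hypothetical stand-in for Hive's ObjectInspector hierarchy.
sealed trait Inspector
case object IntInspector extends Inspector
case object StringInspector extends Inspector

// Before: the match runs once per wrapped value.
def wrapSlow(value: Any, oi: Inspector): Any = oi match {
  case IntInspector    => value.asInstanceOf[java.lang.Integer]
  case StringInspector => value.toString
}

// After: the match runs once; the returned closure is cached and reused.
def wrapperFor(oi: Inspector): Any => Any = oi match {
  case IntInspector    => (v: Any) => v.asInstanceOf[java.lang.Integer]
  case StringInspector => (v: Any) => v.toString
}

val rows: Seq[Any] = Seq(1, 2, 3)
val wrapper = wrapperFor(IntInspector) // resolved once per operator
println(rows.map(wrapper))             // hot loop pays only a closure call
```
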
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index e4b963efea..c3c4351cf5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -238,102 +238,161 @@ private[hive] trait HiveInspectors {
case c => throw new AnalysisException(s"Unsupported java type $c")
}
+ private def withNullSafe(f: Any => Any): Any => Any = {
+ input => if (input == null) null else f(input)
+ }
+
/**
* Wraps with Hive types based on object inspector.
- * TODO: Consolidate all hive OI/data interface code.
*/
protected def wrapperFor(oi: ObjectInspector, dataType: DataType): Any => Any = oi match {
- case _: JavaHiveVarcharObjectInspector =>
- (o: Any) =>
- if (o != null) {
- val s = o.asInstanceOf[UTF8String].toString
- new HiveVarchar(s, s.length)
- } else {
- null
- }
-
- case _: JavaHiveCharObjectInspector =>
- (o: Any) =>
- if (o != null) {
- val s = o.asInstanceOf[UTF8String].toString
- new HiveChar(s, s.length)
- } else {
- null
- }
-
- case _: JavaHiveDecimalObjectInspector =>
- (o: Any) =>
- if (o != null) {
- HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal)
- } else {
- null
- }
-
- case _: JavaDateObjectInspector =>
- (o: Any) =>
- if (o != null) {
- DateTimeUtils.toJavaDate(o.asInstanceOf[Int])
- } else {
- null
- }
-
- case _: JavaTimestampObjectInspector =>
+ case x: ConstantObjectInspector =>
(o: Any) =>
- if (o != null) {
- DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long])
- } else {
- null
+ x.getWritableConstantValue
+ case x: PrimitiveObjectInspector => x match {
+ // TODO we don't support the HiveVarcharObjectInspector yet.
+ case _: StringObjectInspector if x.preferWritable() =>
+ withNullSafe(o => getStringWritable(o))
+ case _: StringObjectInspector =>
+ withNullSafe(o => o.asInstanceOf[UTF8String].toString())
+ case _: IntObjectInspector if x.preferWritable() =>
+ withNullSafe(o => getIntWritable(o))
+ case _: IntObjectInspector =>
+ withNullSafe(o => o.asInstanceOf[java.lang.Integer])
+ case _: BooleanObjectInspector if x.preferWritable() =>
+ withNullSafe(o => getBooleanWritable(o))
+ case _: BooleanObjectInspector =>
+ withNullSafe(o => o.asInstanceOf[java.lang.Boolean])
+ case _: FloatObjectInspector if x.preferWritable() =>
+ withNullSafe(o => getFloatWritable(o))
+ case _: FloatObjectInspector =>
+ withNullSafe(o => o.asInstanceOf[java.lang.Float])
+ case _: DoubleObjectInspector if x.preferWritable() =>
+ withNullSafe(o => getDoubleWritable(o))
+ case _: DoubleObjectInspector =>
+ withNullSafe(o => o.asInstanceOf[java.lang.Double])
+ case _: LongObjectInspector if x.preferWritable() =>
+ withNullSafe(o => getLongWritable(o))
+ case _: LongObjectInspector =>
+ withNullSafe(o => o.asInstanceOf[java.lang.Long])
+ case _: ShortObjectInspector if x.preferWritable() =>
+ withNullSafe(o => getShortWritable(o))
+ case _: ShortObjectInspector =>
+ withNullSafe(o => o.asInstanceOf[java.lang.Short])
+ case _: ByteObjectInspector if x.preferWritable() =>
+ withNullSafe(o => getByteWritable(o))
+ case _: ByteObjectInspector =>
+ withNullSafe(o => o.asInstanceOf[java.lang.Byte])
+ case _: JavaHiveVarcharObjectInspector =>
+ withNullSafe { o =>
+ val s = o.asInstanceOf[UTF8String].toString
+ new HiveVarchar(s, s.length)
}
+ case _: JavaHiveCharObjectInspector =>
+ withNullSafe { o =>
+ val s = o.asInstanceOf[UTF8String].toString
+ new HiveChar(s, s.length)
+ }
+ case _: JavaHiveDecimalObjectInspector =>
+ withNullSafe(o =>
+ HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal))
+ case _: JavaDateObjectInspector =>
+ withNullSafe(o =>
+ DateTimeUtils.toJavaDate(o.asInstanceOf[Int]))
+ case _: JavaTimestampObjectInspector =>
+ withNullSafe(o =>
+ DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]))
+ case _: HiveDecimalObjectInspector if x.preferWritable() =>
+ withNullSafe(o => getDecimalWritable(o.asInstanceOf[Decimal]))
+ case _: HiveDecimalObjectInspector =>
+ withNullSafe(o =>
+ HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal))
+ case _: BinaryObjectInspector if x.preferWritable() =>
+ withNullSafe(o => getBinaryWritable(o))
+ case _: BinaryObjectInspector =>
+ withNullSafe(o => o.asInstanceOf[Array[Byte]])
+ case _: DateObjectInspector if x.preferWritable() =>
+ withNullSafe(o => getDateWritable(o))
+ case _: DateObjectInspector =>
+ withNullSafe(o => DateTimeUtils.toJavaDate(o.asInstanceOf[Int]))
+ case _: TimestampObjectInspector if x.preferWritable() =>
+ withNullSafe(o => getTimestampWritable(o))
+ case _: TimestampObjectInspector =>
+ withNullSafe(o => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]))
+ }
case soi: StandardStructObjectInspector =>
val schema = dataType.asInstanceOf[StructType]
val wrappers = soi.getAllStructFieldRefs.asScala.zip(schema.fields).map {
case (ref, field) => wrapperFor(ref.getFieldObjectInspector, field.dataType)
}
- (o: Any) => {
- if (o != null) {
- val struct = soi.create()
- val row = o.asInstanceOf[InternalRow]
- soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach {
- case ((field, wrapper), i) =>
- soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType)))
- }
- struct
- } else {
- null
+ withNullSafe { o =>
+ val struct = soi.create()
+ val row = o.asInstanceOf[InternalRow]
+ soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach {
+ case ((field, wrapper), i) =>
+ soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType)))
+ }
+ struct
+ }
+
+ case ssoi: SettableStructObjectInspector =>
+ val structType = dataType.asInstanceOf[StructType]
+ val wrappers = ssoi.getAllStructFieldRefs.asScala.zip(structType).map {
+ case (ref, tpe) => wrapperFor(ref.getFieldObjectInspector, tpe.dataType)
+ }
+ withNullSafe { o =>
+ val row = o.asInstanceOf[InternalRow]
+ // 1. create the pojo (most likely) object
+ val result = ssoi.create()
+ ssoi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach {
+ case ((field, wrapper), i) =>
+ val tpe = structType(i).dataType
+ ssoi.setStructFieldData(
+ result,
+ field,
+ wrapper(row.get(i, tpe)).asInstanceOf[AnyRef])
}
+ result
+ }
+
+ case soi: StructObjectInspector =>
+ val structType = dataType.asInstanceOf[StructType]
+ val wrappers = soi.getAllStructFieldRefs.asScala.zip(structType).map {
+ case (ref, tpe) => wrapperFor(ref.getFieldObjectInspector, tpe.dataType)
+ }
+ withNullSafe { o =>
+ val row = o.asInstanceOf[InternalRow]
+ val result = new java.util.ArrayList[AnyRef](wrappers.size)
+ soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach {
+ case ((field, wrapper), i) =>
+ val tpe = structType(i).dataType
+ result.add(wrapper(row.get(i, tpe)).asInstanceOf[AnyRef])
+ }
+ result
}
case loi: ListObjectInspector =>
val elementType = dataType.asInstanceOf[ArrayType].elementType
val wrapper = wrapperFor(loi.getListElementObjectInspector, elementType)
- (o: Any) => {
- if (o != null) {
- val array = o.asInstanceOf[ArrayData]
- val values = new java.util.ArrayList[Any](array.numElements())
- array.foreach(elementType, (_, e) => values.add(wrapper(e)))
- values
- } else {
- null
- }
+ withNullSafe { o =>
+ val array = o.asInstanceOf[ArrayData]
+ val values = new java.util.ArrayList[Any](array.numElements())
+ array.foreach(elementType, (_, e) => values.add(wrapper(e)))
+ values
}
case moi: MapObjectInspector =>
val mt = dataType.asInstanceOf[MapType]
val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector, mt.keyType)
val valueWrapper = wrapperFor(moi.getMapValueObjectInspector, mt.valueType)
-
- (o: Any) => {
- if (o != null) {
+ withNullSafe { o =>
val map = o.asInstanceOf[MapData]
val jmap = new java.util.HashMap[Any, Any](map.numElements())
map.foreach(mt.keyType, mt.valueType, (k, v) =>
jmap.put(keyWrapper(k), valueWrapper(v)))
jmap
- } else {
- null
}
- }
case _ =>
identity[Any]
@@ -648,119 +707,19 @@ private[hive] trait HiveInspectors {
(value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapper(value)
}
- /**
- * Converts native catalyst types to the types expected by Hive
- * @param a the value to be wrapped
- * @param oi This ObjectInspector associated with the value returned by this function, and
- * the ObjectInspector should also be consistent with those returned from
- * toInspector: DataType => ObjectInspector and
- * toInspector: Expression => ObjectInspector
- *
- * Strictly follows the following order in wrapping (constant OI has the higher priority):
- * Constant object inspector => return the bundled value of Constant object inspector
- * Check whether the `a` is null => return null if true
- * If object inspector prefers writable object => return a Writable for the given data `a`
- * Map the catalyst data to the boxed java primitive
- *
- * NOTICE: the complex data type requires recursive wrapping.
- */
- def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = oi match {
- case x: ConstantObjectInspector => x.getWritableConstantValue
- case _ if a == null => null
- case x: PrimitiveObjectInspector => x match {
- // TODO we don't support the HiveVarcharObjectInspector yet.
- case _: StringObjectInspector if x.preferWritable() => getStringWritable(a)
- case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString()
- case _: IntObjectInspector if x.preferWritable() => getIntWritable(a)
- case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer]
- case _: BooleanObjectInspector if x.preferWritable() => getBooleanWritable(a)
- case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean]
- case _: FloatObjectInspector if x.preferWritable() => getFloatWritable(a)
- case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float]
- case _: DoubleObjectInspector if x.preferWritable() => getDoubleWritable(a)
- case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double]
- case _: LongObjectInspector if x.preferWritable() => getLongWritable(a)
- case _: LongObjectInspector => a.asInstanceOf[java.lang.Long]
- case _: ShortObjectInspector if x.preferWritable() => getShortWritable(a)
- case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short]
- case _: ByteObjectInspector if x.preferWritable() => getByteWritable(a)
- case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte]
- case _: HiveDecimalObjectInspector if x.preferWritable() =>
- getDecimalWritable(a.asInstanceOf[Decimal])
- case _: HiveDecimalObjectInspector =>
- HiveDecimal.create(a.asInstanceOf[Decimal].toJavaBigDecimal)
- case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a)
- case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]]
- case _: DateObjectInspector if x.preferWritable() => getDateWritable(a)
- case _: DateObjectInspector => DateTimeUtils.toJavaDate(a.asInstanceOf[Int])
- case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a)
- case _: TimestampObjectInspector => DateTimeUtils.toJavaTimestamp(a.asInstanceOf[Long])
- }
- case x: SettableStructObjectInspector =>
- val fieldRefs = x.getAllStructFieldRefs
- val structType = dataType.asInstanceOf[StructType]
- val row = a.asInstanceOf[InternalRow]
- // 1. create the pojo (most likely) object
- val result = x.create()
- var i = 0
- val size = fieldRefs.size
- while (i < size) {
- // 2. set the property for the pojo
- val tpe = structType(i).dataType
- x.setStructFieldData(
- result,
- fieldRefs.get(i),
- wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe))
- i += 1
- }
-
- result
- case x: StructObjectInspector =>
- val fieldRefs = x.getAllStructFieldRefs
- val structType = dataType.asInstanceOf[StructType]
- val row = a.asInstanceOf[InternalRow]
- val result = new java.util.ArrayList[AnyRef](fieldRefs.size)
- var i = 0
- val size = fieldRefs.size
- while (i < size) {
- val tpe = structType(i).dataType
- result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe))
- i += 1
- }
-
- result
- case x: ListObjectInspector =>
- val list = new java.util.ArrayList[Object]
- val tpe = dataType.asInstanceOf[ArrayType].elementType
- a.asInstanceOf[ArrayData].foreach(tpe, (_, e) =>
- list.add(wrap(e, x.getListElementObjectInspector, tpe))
- )
- list
- case x: MapObjectInspector =>
- val keyType = dataType.asInstanceOf[MapType].keyType
- val valueType = dataType.asInstanceOf[MapType].valueType
- val map = a.asInstanceOf[MapData]
-
- // Some UDFs seem to assume we pass in a HashMap.
- val hashMap = new java.util.HashMap[Any, Any](map.numElements())
-
- map.foreach(keyType, valueType, (k, v) =>
- hashMap.put(wrap(k, x.getMapKeyObjectInspector, keyType),
- wrap(v, x.getMapValueObjectInspector, valueType))
- )
-
- hashMap
+ def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = {
+ wrapperFor(oi, dataType)(a).asInstanceOf[AnyRef]
}
def wrap(
row: InternalRow,
- inspectors: Seq[ObjectInspector],
+ wrappers: Array[(Any) => Any],
cache: Array[AnyRef],
dataTypes: Array[DataType]): Array[AnyRef] = {
var i = 0
- val length = inspectors.length
+ val length = wrappers.length
while (i < length) {
- cache(i) = wrap(row.get(i, dataTypes(i)), inspectors(i), dataTypes(i))
+ cache(i) = wrappers(i)(row.get(i, dataTypes(i))).asInstanceOf[AnyRef]
i += 1
}
cache
@@ -768,13 +727,13 @@ private[hive] trait HiveInspectors {
def wrap(
row: Seq[Any],
- inspectors: Seq[ObjectInspector],
+ wrappers: Array[(Any) => Any],
cache: Array[AnyRef],
dataTypes: Array[DataType]): Array[AnyRef] = {
var i = 0
- val length = inspectors.length
+ val length = wrappers.length
while (i < length) {
- cache(i) = wrap(row(i), inspectors(i), dataTypes(i))
+ cache(i) = wrappers(i)(row(i)).asInstanceOf[AnyRef]
i += 1
}
cache
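
Note the new `wrap` overloads above: they take a precomputed `Array[(Any) => Any]` instead of a `Seq[ObjectInspector]`, so the per-row loop never touches an inspector. An illustrative caller, with `Seq[Any]` standing in for Spark's `InternalRow` (a simplified sketch, not the Spark API):

```scala
// Simplified mirror of the new wrap(row, wrappers, cache, dataTypes) overload;
// dataTypes is dropped because this sketch takes Seq[Any] rather than InternalRow.
def wrap(
    row: Seq[Any],
    wrappers: Array[Any => Any],
    cache: Array[AnyRef]): Array[AnyRef] = {
  var i = 0
  while (i < wrappers.length) {
    cache(i) = wrappers(i)(row(i)).asInstanceOf[AnyRef]
    i += 1
  }
  cache
}

// Wrappers are resolved once, outside the per-row path...
val wrappers: Array[Any => Any] =
  Array(v => v.asInstanceOf[java.lang.Integer], v => v.toString)
// ...and the cache array is reused across rows to avoid per-row allocation.
val cache = new Array[AnyRef](wrappers.length)
println(wrap(Seq(1, "a"), wrappers, cache).mkString(", "))
println(wrap(Seq(2, "b"), wrappers, cache).mkString(", "))
```
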
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 962dd5a52e..d54913518b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -71,6 +71,9 @@ private[hive] case class HiveSimpleUDF(
override lazy val dataType = javaClassToDataType(method.getReturnType)
@transient
+ private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+
+ @transient
lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
method.getGenericReturnType(), ObjectInspectorOptions.JAVA))
@@ -82,7 +85,7 @@ private[hive] case class HiveSimpleUDF(
// TODO: Finish input output types.
override def eval(input: InternalRow): Any = {
- val inputs = wrap(children.map(_.eval(input)), arguments, cached, inputDataTypes)
+ val inputs = wrap(children.map(_.eval(input)), wrappers, cached, inputDataTypes)
val ret = FunctionRegistry.invoke(
method,
function,
@@ -215,6 +218,9 @@ private[hive] case class HiveGenericUDTF(
private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
@transient
+ private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+
+ @transient
private lazy val unwrapper = unwrapperFor(outputInspector)
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -222,7 +228,7 @@ private[hive] case class HiveGenericUDTF(
val inputProjection = new InterpretedProjection(children)
- function.process(wrap(inputProjection(input), inputInspectors, udtInput, inputDataTypes))
+ function.process(wrap(inputProjection(input), wrappers, udtInput, inputDataTypes))
collector.collectRows()
}
@@ -297,6 +303,9 @@ private[hive] case class HiveUDAFFunction(
private lazy val function = functionAndInspector._1
@transient
+ private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+
+ @transient
private lazy val returnInspector = functionAndInspector._2
@transient
@@ -322,7 +331,7 @@ private[hive] case class HiveUDAFFunction(
override def update(_buffer: MutableRow, input: InternalRow): Unit = {
val inputs = inputProjection(input)
- function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes))
+ function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes))
}
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
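
All three UDF expressions touched in hiveUDFs.scala apply the same template: a `@transient private lazy val wrappers` built from the children once, then fed to `wrap` on every row. A minimal sketch of that caching shape, with stand-in types (`Child` and `FakeUDF` are hypothetical, not Spark's Expression classes):

```scala
// Stand-in type; Spark's real children are Expressions with ObjectInspectors.
case class Child(name: String)

class FakeUDF(children: Seq[Child]) extends Serializable {
  // Resolved lazily on first use and rebuilt after deserialization on
  // executors, which is why the real code marks the field @transient.
  @transient private lazy val wrappers: Array[Any => Any] =
    children.map(c => (v: Any) => s"${c.name}=$v").toArray

  def eval(row: Seq[Any]): Seq[Any] =
    row.zip(wrappers).map { case (v, w) => w(v) }
}

// Wrapper construction happens once per instance, not once per eval call.
val udf = new FakeUDF(Seq(Child("a"), Child("b")))
println(udf.eval(Seq(1, 2)))
println(udf.eval(Seq(3, 4)))
```
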