From ae9f128608f67cbee0a2fb24754783ee3b4f3098 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 18 Dec 2014 20:21:52 -0800 Subject: [SPARK-4573] [SQL] Add SettableStructObjectInspector support in "wrap" function Hive UDAF may create an customized object constructed by SettableStructObjectInspector, this is critical when integrate Hive UDAF with the refactor-ed UDAF interface. Performance issue in `wrap/unwrap` since more match cases added, will do it in another PR. Author: Cheng Hao Closes #3429 from chenghao-intel/settable_oi and squashes the following commits: 9f0aff3 [Cheng Hao] update code style issues as feedbacks 2b0561d [Cheng Hao] Add more scala doc f5a40e8 [Cheng Hao] add scala doc 2977e9b [Cheng Hao] remove the timezone setting for test suite 3ed284c [Cheng Hao] fix the date type comparison f1b6749 [Cheng Hao] Update the comment 932940d [Cheng Hao] Add more unit test 72e4332 [Cheng Hao] Add settable StructObjectInspector support --- .../scala/org/apache/spark/sql/hive/Shim13.scala | 130 +++++++++++++-------- 1 file changed, 84 insertions(+), 46 deletions(-) (limited to 'sql/hive/v0.13.1') diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 7c8cbf10c1..b78c75798e 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -22,6 +22,7 @@ import java.util.Properties import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.NullWritable import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.common.`type`.{HiveDecimal} @@ -163,91 +164,123 @@ private[hive] object HiveShim { new TableDesc(inputFormatClass, outputFormatClass, properties) } + def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.stringTypeInfo, - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])) + TypeInfoFactory.stringTypeInfo, getStringWritable(value)) def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.intTypeInfo, - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])) + TypeInfoFactory.intTypeInfo, getIntWritable(value)) def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.doubleTypeInfo, if (value == null) { - null - } else { - new hiveIo.DoubleWritable(value.asInstanceOf[Double]) - }) + TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value)) def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.booleanTypeInfo, if (value == null) { - null - } else { - new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) - }) + TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.longTypeInfo, - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])) + TypeInfoFactory.longTypeInfo, getLongWritable(value)) def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.floatTypeInfo, if (value == null) { - null - } else { - new hadoopIo.FloatWritable(value.asInstanceOf[Float]) - }) + TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.shortTypeInfo, - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])) + TypeInfoFactory.shortTypeInfo, getShortWritable(value)) def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.byteTypeInfo, - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])) + TypeInfoFactory.byteTypeInfo, getByteWritable(value)) def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.binaryTypeInfo, if (value == null) { - null - } else { - new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) - }) + TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.dateTypeInfo, - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])) + TypeInfoFactory.dateTypeInfo, getDateWritable(value)) def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.timestampTypeInfo, if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - }) + TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.decimalTypeInfo, - if (value == null) { - null - } else { - // TODO precise, scale? - new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) - }) + TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( TypeInfoFactory.voidTypeInfo, null) + def getStringWritable(value: Any): hadoopIo.Text = + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + + def getIntWritable(value: Any): hadoopIo.IntWritable = + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) + + def getDoubleWritable(value: Any): hiveIo.DoubleWritable = + if (value == null) { + null + } else { + new hiveIo.DoubleWritable(value.asInstanceOf[Double]) + } + + def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = + if (value == null) { + null + } else { + new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) + } + + def getLongWritable(value: Any): hadoopIo.LongWritable = + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) + + def getFloatWritable(value: Any): hadoopIo.FloatWritable = + if (value == null) { + null + } else { + new hadoopIo.FloatWritable(value.asInstanceOf[Float]) + } + + def getShortWritable(value: Any): hiveIo.ShortWritable = + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) + + def getByteWritable(value: Any): hiveIo.ByteWritable = + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) + + def getBinaryWritable(value: Any): hadoopIo.BytesWritable = + if (value == null) { + null + } else { + new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) + } + + def getDateWritable(value: Any): hiveIo.DateWritable = + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]) + + def getTimestampWritable(value: Any): hiveIo.TimestampWritable = + if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + } + + def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = + if (value == null) { + null + } else { + // TODO precise, scale? + new hiveIo.HiveDecimalWritable( + HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) + } + + def getPrimitiveNullWritable: NullWritable = NullWritable.get() + def createDriverResultsArray = new JArrayList[Object] def processResults(results: JArrayList[Object]) = { @@ -355,7 +388,12 @@ private[hive] object HiveShim { } def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + if (hdoi.preferWritable()) { + Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, + hdoi.precision(), hdoi.scale()) + } else { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } } } -- cgit v1.2.3