From ae9f128608f67cbee0a2fb24754783ee3b4f3098 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 18 Dec 2014 20:21:52 -0800 Subject: [SPARK-4573] [SQL] Add SettableStructObjectInspector support in "wrap" function Hive UDAF may create an customized object constructed by SettableStructObjectInspector, this is critical when integrate Hive UDAF with the refactor-ed UDAF interface. Performance issue in `wrap/unwrap` since more match cases added, will do it in another PR. Author: Cheng Hao Closes #3429 from chenghao-intel/settable_oi and squashes the following commits: 9f0aff3 [Cheng Hao] update code style issues as feedbacks 2b0561d [Cheng Hao] Add more scala doc f5a40e8 [Cheng Hao] add scala doc 2977e9b [Cheng Hao] remove the timezone setting for test suite 3ed284c [Cheng Hao] fix the date type comparison f1b6749 [Cheng Hao] Update the comment 932940d [Cheng Hao] Add more unit test 72e4332 [Cheng Hao] Add settable StructObjectInspector support --- .../scala/org/apache/spark/sql/hive/Shim12.scala | 87 ++++++++++++++++------ 1 file changed, 65 insertions(+), 22 deletions(-) (limited to 'sql/hive/v0.12.0/src/main') diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 67cc886575..2d01a85067 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} import org.apache.hadoop.hive.serde2.{io => hiveIo} +import org.apache.hadoop.io.NullWritable import org.apache.hadoop.{io => hadoopIo} import org.apache.hadoop.mapred.InputFormat import org.apache.spark.sql.catalyst.types.decimal.Decimal @@ -71,76 +72,114 @@ private[hive] object HiveShim { def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.STRING, - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])) + getStringWritable(value)) def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.INT, - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])) + getIntWritable(value)) def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.DOUBLE, - if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double])) + getDoubleWritable(value)) def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.BOOLEAN, - if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean])) + getBooleanWritable(value)) def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.LONG, - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])) + getLongWritable(value)) def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.FLOAT, - if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float])) + getFloatWritable(value)) def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.SHORT, - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])) + getShortWritable(value)) def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.BYTE, - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])) + getByteWritable(value)) def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.BINARY, - if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]])) + getBinaryWritable(value)) def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.DATE, - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])) + getDateWritable(value)) def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.TIMESTAMP, - if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - }) + getTimestampWritable(value)) def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.DECIMAL, - if (value == null) { - null - } else { - new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) - }) + getDecimalWritable(value)) def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.VOID, null) + def getStringWritable(value: Any): hadoopIo.Text = + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + + def getIntWritable(value: Any): hadoopIo.IntWritable = + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) + + def getDoubleWritable(value: Any): hiveIo.DoubleWritable = + if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double]) + + def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = + if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) + + def getLongWritable(value: Any): hadoopIo.LongWritable = + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) + + def getFloatWritable(value: Any): hadoopIo.FloatWritable = + if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float]) + + def getShortWritable(value: Any): hiveIo.ShortWritable = + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) + + def getByteWritable(value: Any): hiveIo.ByteWritable = + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) + + def getBinaryWritable(value: Any): hadoopIo.BytesWritable = + if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) + + def getDateWritable(value: Any): hiveIo.DateWritable = + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]) + + def getTimestampWritable(value: Any): hiveIo.TimestampWritable = + if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + } + + def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = + if (value == null) { + null + } else { + new hiveIo.HiveDecimalWritable( + HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) + } + + def getPrimitiveNullWritable: NullWritable = NullWritable.get() + def createDriverResultsArray = new JArrayList[String] def processResults(results: JArrayList[String]) = results @@ -197,7 +236,11 @@ private[hive] object HiveShim { } def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + if (hdoi.preferWritable()) { + Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue) + } else { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + } } } -- cgit v1.2.3