author     Cheng Hao <hao.cheng@intel.com>    2014-12-18 20:21:52 -0800
committer  Michael Armbrust <michael@databricks.com>    2014-12-18 20:21:52 -0800
commit     ae9f128608f67cbee0a2fb24754783ee3b4f3098 (patch)
tree       c841d42ba331ae19bcf9f0c467704f95691f662c /sql
parent     7687415c2578b5bdc79c9646c246e52da9a4dd4a (diff)
[SPARK-4573] [SQL] Add SettableStructObjectInspector support in "wrap" function
A Hive UDAF may create a customized object constructed by a SettableStructObjectInspector; this is critical when integrating Hive UDAFs with the refactored UDAF interface. There is a performance issue in `wrap`/`unwrap` because of the additional match cases; that will be addressed in another PR.

Author: Cheng Hao <hao.cheng@intel.com>

Closes #3429 from chenghao-intel/settable_oi and squashes the following commits:

9f0aff3 [Cheng Hao] update code style issues as feedbacks
2b0561d [Cheng Hao] Add more scala doc
f5a40e8 [Cheng Hao] add scala doc
2977e9b [Cheng Hao] remove the timezone setting for test suite
3ed284c [Cheng Hao] fix the date type comparison
f1b6749 [Cheng Hao] Update the comment
932940d [Cheng Hao] Add more unit test
72e4332 [Cheng Hao] Add settable StructObjectInspector support
Diffstat (limited to 'sql')
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala     | 346
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala | 220
-rw-r--r--  sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala     |  87
-rw-r--r--  sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala     | 130
4 files changed, 659 insertions, 124 deletions
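
Editor's note: the sketch below is an illustrative Scala example, not part of the diff; the object name SettableStructExample is hypothetical, and it lives in the org.apache.spark.sql.hive package because HiveInspectors and HiveShim are package private. It mirrors the new HiveInspectorSuite test and shows the core scenario this patch enables: a Hive UDAF exposes its partial-aggregation state through a reflection-based SettableStructObjectInspector, the state is unwrapped into a Catalyst Row, and the new "wrap" path converts the Row back into the UDAF's POJO.

package org.apache.spark.sql.hive

import org.apache.hadoop.hive.ql.udf.UDAFPercentile
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, StructObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
import org.apache.hadoop.io.LongWritable

import org.apache.spark.sql.catalyst.expressions.Row

object SettableStructExample extends HiveInspectors {
  def main(args: Array[String]): Unit = {
    // Build a partial aggregation state with Hive's percentile UDAF.
    val udaf = new UDAFPercentile.PercentileLongEvaluator()
    udaf.init()
    udaf.iterate(new LongWritable(1), 0.1)
    val state = udaf.terminatePartial()

    // Reflection-based (settable) struct inspector over the UDAF's POJO state.
    val soi = ObjectInspectorFactory.getReflectionObjectInspector(
      classOf[UDAFPercentile.State], ObjectInspectorOptions.JAVA)
      .asInstanceOf[StructObjectInspector]

    // Hive POJO => Catalyst Row => Hive POJO, exercising the new wrap path.
    val catalystRow = unwrap(state, soi).asInstanceOf[Row]
    val rebuilt = wrap(catalystRow, soi).asInstanceOf[UDAFPercentile.State]
    println(soi.getStructFieldData(rebuilt, soi.getStructFieldRef("counts")))
  }
}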
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 0eeac8620f..06189341f8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -18,9 +18,7 @@
package org.apache.spark.sql.hive
import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
-import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory
import org.apache.hadoop.hive.serde2.objectinspector._
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector._
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}
@@ -33,6 +31,145 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
/* Implicit conversions */
import scala.collection.JavaConversions._
+/**
+ * 1. The underlying data types in Catalyst and in Hive
+ * In catalyst:
+ * Primitive =>
+ * java.lang.String
+ * int / scala.Int
+ * boolean / scala.Boolean
+ * float / scala.Float
+ * double / scala.Double
+ * long / scala.Long
+ * short / scala.Short
+ * byte / scala.Byte
+ * org.apache.spark.sql.catalyst.types.decimal.Decimal
+ * Array[Byte]
+ * java.sql.Date
+ * java.sql.Timestamp
+ * Complex Types =>
+ * Map: scala.collection.immutable.Map
+ * List: scala.collection.immutable.Seq
+ * Struct:
+ * org.apache.spark.sql.catalyst.expression.Row
+ * Union: NOT SUPPORTED YET
+ * The complex types act as containers, which can hold arbitrary data types.
+ *
+ * In Hive, the native data types used in UDF/UDAF/UDTF evaluation vary, and each is
+ * associated with an Object Inspector. In the Hive expression evaluation framework, the
+ * underlying data are:
+ * Primitive Type
+ * Java Boxed Primitives:
+ * org.apache.hadoop.hive.common.type.HiveVarchar
+ * java.lang.String
+ * java.lang.Integer
+ * java.lang.Boolean
+ * java.lang.Float
+ * java.lang.Double
+ * java.lang.Long
+ * java.lang.Short
+ * java.lang.Byte
+ * org.apache.hadoop.hive.common.`type`.HiveDecimal
+ * byte[]
+ * java.sql.Date
+ * java.sql.Timestamp
+ * Writables:
+ * org.apache.hadoop.hive.serde2.io.HiveVarcharWritable
+ * org.apache.hadoop.io.Text
+ * org.apache.hadoop.io.IntWritable
+ * org.apache.hadoop.hive.serde2.io.DoubleWritable
+ * org.apache.hadoop.io.BooleanWritable
+ * org.apache.hadoop.io.LongWritable
+ * org.apache.hadoop.io.FloatWritable
+ * org.apache.hadoop.hive.serde2.io.ShortWritable
+ * org.apache.hadoop.hive.serde2.io.ByteWritable
+ * org.apache.hadoop.io.BytesWritable
+ * org.apache.hadoop.hive.serde2.io.DateWritable
+ * org.apache.hadoop.hive.serde2.io.TimestampWritable
+ * org.apache.hadoop.hive.serde2.io.HiveDecimalWritable
+ * Complex Type
+ * List: Object[] / java.util.List
+ * Map: java.util.Map
+ * Struct: Object[] / java.util.List / java POJO
+ * Union: class StandardUnion { byte tag; Object object }
+ *
+ * NOTICE: HiveVarchar is not supported by Catalyst; it is simply treated as the String type.
+ *
+ *
+ * 2. Hive ObjectInspector is a group of flexible APIs for inspecting values in different data
+ * representations, and developers can extend those APIs as needed, so technically an
+ * object inspector can support an arbitrary Java data type.
+ *
+ * Fortunately, only a few built-in Hive Object Inspectors are used in generic UDF/UDAF/UDTF
+ * evaluation.
+ * 1) Primitive Types (PrimitiveObjectInspector & its sub classes)
+ {{{
+ public interface PrimitiveObjectInspector {
+ // Writables (hadoop.io.IntWritable, hadoop.io.Text etc.)
+ Object getPrimitiveWritableObject(Object o);
+ // Java Primitives (java.lang.Integer, java.lang.String etc.)
+ Object getPrimitiveJavaObject(Object o);
+ // If true, the ObjectInspector only inspects the `writable` representation, so we need to
+ // check this before invoking the methods above.
+ boolean preferWritable();
+ ...
+ }
+ }}}
+
+ * 2) Complex Types:
+ * ListObjectInspector: inspects a java array or [[java.util.List]]
+ * MapObjectInspector: inspects [[java.util.Map]]
+ * StructObjectInspector: inspects a java array, [[java.util.List]] and
+ * even a normal java object (POJO)
+ * UnionObjectInspector: (tag: Int, object data) (TODO: not supported by SparkSQL yet)
+ *
+ * 3) ConstantObjectInspector:
+ * A constant object inspector can be of either a primitive or a complex type, and it bundles a
+ * constant value as its property; usually the value is created when the constant object
+ * inspector is constructed.
+ * {{{
+ public interface ConstantObjectInspector extends ObjectInspector {
+ Object getWritableConstantValue();
+ ...
+ }
+ }}}
+ * Hive provides three categories of built-in constant object inspectors:
+ * Primitive Object Inspectors:
+ * WritableConstantStringObjectInspector
+ * WritableConstantHiveVarcharObjectInspector
+ * WritableConstantHiveDecimalObjectInspector
+ * WritableConstantTimestampObjectInspector
+ * WritableConstantIntObjectInspector
+ * WritableConstantDoubleObjectInspector
+ * WritableConstantBooleanObjectInspector
+ * WritableConstantLongObjectInspector
+ * WritableConstantFloatObjectInspector
+ * WritableConstantShortObjectInspector
+ * WritableConstantByteObjectInspector
+ * WritableConstantBinaryObjectInspector
+ * WritableConstantDateObjectInspector
+ * Map Object Inspector:
+ * StandardConstantMapObjectInspector
+ * List Object Inspector:
+ * StandardConstantListObjectInspector
+ * Struct Object Inspector: Hive doesn't provide the built-in constant object inspector for Struct
+ * Union Object Inspector: Hive doesn't provide the built-in constant object inspector for Union
+ *
+ *
+ * 3. This trait facilitates:
+ * Data Unwrapping: Hive Data => Catalyst Data (unwrap)
+ * Data Wrapping: Catalyst Data => Hive Data (wrap)
+ * Binding the Object Inspector for Catalyst Data (toInspector)
+ * Retrieving the Catalyst Data Type from Object Inspector (inspectorToDataType)
+ *
+ *
+ * 4. Future Improvement (TODO)
+ * This implementation is quite ugly and inefficient:
+ * a. Pattern matching at runtime
+ * b. Small object creation when converting catalyst data => writable
+ * c. Unnecessary unwrap / wrap for nested UDF invocation:
+ * e.g. date_add(printf("%s-%s-%s", a, b, c), 3)
+ * We don't need to unwrap the data for printf and wrap it again before passing it into date_add
+ */
private[hive] trait HiveInspectors {
def javaClassToDataType(clz: Class[_]): DataType = clz match {
@@ -87,10 +224,23 @@ private[hive] trait HiveInspectors {
* @param oi the ObjectInspector associated with the Hive Type
* @return convert the data into catalyst type
* TODO return the function of (data => Any) instead for performance consideration
+ *
+ * Unwrapping strictly follows the order below (a constant OI has higher priority):
+ * Constant Null object inspector =>
+ * return null
+ * Constant object inspector =>
+ * extract the value from constant object inspector
+ * Check whether the `data` is null =>
+ * return null if true
+ * If object inspector prefers writable =>
+ * extract writable from `data` and then get the catalyst type from the writable
+ * Extract the java object directly from the object inspector
+ *
+ * NOTICE: the complex data type requires recursive unwrapping.
*/
def unwrap(data: Any, oi: ObjectInspector): Any = oi match {
- case _ if data == null => null
- case poi: VoidObjectInspector => null
+ case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null
+ case poi: WritableConstantStringObjectInspector => poi.getWritableConstantValue.toString
case poi: WritableConstantHiveVarcharObjectInspector =>
poi.getWritableConstantValue.getHiveVarchar.getValue
case poi: WritableConstantHiveDecimalObjectInspector =>
@@ -119,12 +269,44 @@ private[hive] trait HiveInspectors {
System.arraycopy(writable.getBytes, 0, temp, 0, temp.length)
temp
case poi: WritableConstantDateObjectInspector => poi.getWritableConstantValue.get()
- case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue
- case hdoi: HiveDecimalObjectInspector => HiveShim.toCatalystDecimal(hdoi, data)
- // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object
- // if next timestamp is null, so Timestamp object is cloned
- case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone()
- case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
+ case mi: StandardConstantMapObjectInspector =>
+ // take the value from the map inspector object, rather than the input data
+ mi.getWritableConstantValue.map { case (k, v) =>
+ (unwrap(k, mi.getMapKeyObjectInspector),
+ unwrap(v, mi.getMapValueObjectInspector))
+ }.toMap
+ case li: StandardConstantListObjectInspector =>
+ // take the value from the list inspector object, rather than the input data
+ li.getWritableConstantValue.map(unwrap(_, li.getListElementObjectInspector)).toSeq
+ // if the value is null, we don't care about the object inspector type
+ case _ if data == null => null
+ case poi: VoidObjectInspector => null // always be null for void object inspector
+ case pi: PrimitiveObjectInspector => pi match {
+ // HiveVarchar is treated as a String
+ case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() =>
+ hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue
+ case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue
+ case x: StringObjectInspector if x.preferWritable() =>
+ x.getPrimitiveWritableObject(data).toString
+ case x: IntObjectInspector if x.preferWritable() => x.get(data)
+ case x: BooleanObjectInspector if x.preferWritable() => x.get(data)
+ case x: FloatObjectInspector if x.preferWritable() => x.get(data)
+ case x: DoubleObjectInspector if x.preferWritable() => x.get(data)
+ case x: LongObjectInspector if x.preferWritable() => x.get(data)
+ case x: ShortObjectInspector if x.preferWritable() => x.get(data)
+ case x: ByteObjectInspector if x.preferWritable() => x.get(data)
+ case x: HiveDecimalObjectInspector => HiveShim.toCatalystDecimal(x, data)
+ case x: BinaryObjectInspector if x.preferWritable() =>
+ x.getPrimitiveWritableObject(data).copyBytes()
+ case x: DateObjectInspector if x.preferWritable() =>
+ x.getPrimitiveWritableObject(data).get()
+ // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object
+ // if next timestamp is null, so Timestamp object is cloned
+ case x: TimestampObjectInspector if x.preferWritable() =>
+ x.getPrimitiveWritableObject(data).getTimestamp.clone()
+ case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone()
+ case _ => pi.getPrimitiveJavaObject(data)
+ }
case li: ListObjectInspector =>
Option(li.getList(data))
.map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq)
@@ -132,10 +314,11 @@ private[hive] trait HiveInspectors {
case mi: MapObjectInspector =>
Option(mi.getMap(data)).map(
_.map {
- case (k,v) =>
+ case (k, v) =>
(unwrap(k, mi.getMapKeyObjectInspector),
unwrap(v, mi.getMapValueObjectInspector))
}.toMap).orNull
+ // currently, Hive doesn't provide a ConstantStructObjectInspector
case si: StructObjectInspector =>
val allRefs = si.getAllStructFieldRefs
new GenericRow(
@@ -191,55 +374,89 @@ private[hive] trait HiveInspectors {
* the ObjectInspector should also be consistent with those returned from
* toInspector: DataType => ObjectInspector and
* toInspector: Expression => ObjectInspector
+ *
+ * Wrapping strictly follows the order below (a constant OI has higher priority):
+ * Constant object inspector => return the bundled value of Constant object inspector
+ * Check whether the `a` is null => return null if true
+ * If object inspector prefers writable object => return a Writable for the given data `a`
+ * Map the catalyst data to the boxed java primitive
+ *
+ * NOTICE: the complex data type requires recursive wrapping.
*/
- def wrap(a: Any, oi: ObjectInspector): AnyRef = if (a == null) {
- null
- } else {
- oi match {
- case x: ConstantObjectInspector => x.getWritableConstantValue
- case x: PrimitiveObjectInspector => a match {
- // TODO what if x.preferWritable() == true? reuse the writable?
- case s: String => s: java.lang.String
- case i: Int => i: java.lang.Integer
- case b: Boolean => b: java.lang.Boolean
- case f: Float => f: java.lang.Float
- case d: Double => d: java.lang.Double
- case l: Long => l: java.lang.Long
- case l: Short => l: java.lang.Short
- case l: Byte => l: java.lang.Byte
- case b: BigDecimal => HiveShim.createDecimal(b.underlying())
- case d: Decimal => HiveShim.createDecimal(d.toBigDecimal.underlying())
- case b: Array[Byte] => b
- case d: java.sql.Date => d
- case t: java.sql.Timestamp => t
+ def wrap(a: Any, oi: ObjectInspector): AnyRef = oi match {
+ case x: ConstantObjectInspector => x.getWritableConstantValue
+ case _ if a == null => null
+ case x: PrimitiveObjectInspector => x match {
+ // TODO we don't support the HiveVarcharObjectInspector yet.
+ case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a)
+ case _: StringObjectInspector => a.asInstanceOf[java.lang.String]
+ case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a)
+ case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer]
+ case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a)
+ case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean]
+ case _: FloatObjectInspector if x.preferWritable() => HiveShim.getFloatWritable(a)
+ case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float]
+ case _: DoubleObjectInspector if x.preferWritable() => HiveShim.getDoubleWritable(a)
+ case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double]
+ case _: LongObjectInspector if x.preferWritable() => HiveShim.getLongWritable(a)
+ case _: LongObjectInspector => a.asInstanceOf[java.lang.Long]
+ case _: ShortObjectInspector if x.preferWritable() => HiveShim.getShortWritable(a)
+ case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short]
+ case _: ByteObjectInspector if x.preferWritable() => HiveShim.getByteWritable(a)
+ case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte]
+ case _: HiveDecimalObjectInspector if x.preferWritable() =>
+ HiveShim.getDecimalWritable(a.asInstanceOf[Decimal])
+ case _: HiveDecimalObjectInspector =>
+ HiveShim.createDecimal(a.asInstanceOf[Decimal].toBigDecimal.underlying())
+ case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a)
+ case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]]
+ case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a)
+ case _: DateObjectInspector => a.asInstanceOf[java.sql.Date]
+ case _: TimestampObjectInspector if x.preferWritable() => HiveShim.getTimestampWritable(a)
+ case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp]
+ }
+ case x: SettableStructObjectInspector =>
+ val fieldRefs = x.getAllStructFieldRefs
+ val row = a.asInstanceOf[Seq[_]]
+ // 1. create the object (most likely a POJO)
+ val result = x.create()
+ var i = 0
+ while (i < fieldRefs.length) {
+ // 2. set the property for the pojo
+ x.setStructFieldData(
+ result,
+ fieldRefs.get(i),
+ wrap(row(i), fieldRefs.get(i).getFieldObjectInspector))
+ i += 1
}
- case x: StructObjectInspector =>
- val fieldRefs = x.getAllStructFieldRefs
- val row = a.asInstanceOf[Seq[_]]
- val result = new java.util.ArrayList[AnyRef](fieldRefs.length)
- var i = 0
- while (i < fieldRefs.length) {
- result.add(wrap(row(i), fieldRefs.get(i).getFieldObjectInspector))
- i += 1
- }
- result
- case x: ListObjectInspector =>
- val list = new java.util.ArrayList[Object]
- a.asInstanceOf[Seq[_]].foreach {
- v => list.add(wrap(v, x.getListElementObjectInspector))
- }
- list
- case x: MapObjectInspector =>
- // Some UDFs seem to assume we pass in a HashMap.
- val hashMap = new java.util.HashMap[AnyRef, AnyRef]()
- hashMap.putAll(a.asInstanceOf[Map[_, _]].map {
- case (k, v) =>
- wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector)
- })
+ result
+ case x: StructObjectInspector =>
+ val fieldRefs = x.getAllStructFieldRefs
+ val row = a.asInstanceOf[Seq[_]]
+ val result = new java.util.ArrayList[AnyRef](fieldRefs.length)
+ var i = 0
+ while (i < fieldRefs.length) {
+ result.add(wrap(row(i), fieldRefs.get(i).getFieldObjectInspector))
+ i += 1
+ }
- hashMap
- }
+ result
+ case x: ListObjectInspector =>
+ val list = new java.util.ArrayList[Object]
+ a.asInstanceOf[Seq[_]].foreach {
+ v => list.add(wrap(v, x.getListElementObjectInspector))
+ }
+ list
+ case x: MapObjectInspector =>
+ // Some UDFs seem to assume we pass in a HashMap.
+ val hashMap = new java.util.HashMap[AnyRef, AnyRef]()
+ hashMap.putAll(a.asInstanceOf[Map[_, _]].map {
+ case (k, v) =>
+ wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector)
+ })
+
+ hashMap
}
def wrap(
@@ -254,6 +471,11 @@ private[hive] trait HiveInspectors {
cache
}
+ /**
+ * @param dataType Catalyst data type
+ * @return the Hive java object inspector (built recursively), not the writable ObjectInspector.
+ * We can easily map to the Hive built-in object inspector according to the data type.
+ */
def toInspector(dataType: DataType): ObjectInspector = dataType match {
case ArrayType(tpe, _) =>
ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
@@ -272,12 +494,20 @@ private[hive] trait HiveInspectors {
case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector
case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
+ // TODO decimal precision?
case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
case StructType(fields) =>
ObjectInspectorFactory.getStandardStructObjectInspector(
fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
}
+ /**
+ * Maps the catalyst expression to an ObjectInspector. If the expression is a [[Literal]] or
+ * is foldable, a constant writable object inspector is returned; otherwise the object
+ * inspector is derived from the expression's catalyst data type.
+ * @param expr Catalyst expression to be mapped
+ * @return Hive java ObjectInspector (built recursively).
+ */
def toInspector(expr: Expression): ObjectInspector = expr match {
case Literal(value, StringType) =>
HiveShim.getStringWritableConstantObjectInspector(value)
@@ -326,8 +556,12 @@ private[hive] trait HiveInspectors {
})
ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map)
}
+ // All of the possible constant expressions are enumerated above; throw an exception if one is missed
case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].")
+ // Ideally, we would not test foldability here (but in the optimizer); however, some
+ // Hive UDFs / UDAFs require their arguments to be constant object inspectors, so we do it eagerly.
case _ if expr.foldable => toInspector(Literal(expr.eval(), expr.dataType))
+ // For non-constant expressions, map to an object inspector according to the data type
case _ => toInspector(expr.dataType)
}
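
Editor's note: a rough illustration (a hedged sketch, not part of the commit; the object name WrapUnwrapRoundTrip is hypothetical, and it sits in org.apache.spark.sql.hive because HiveInspectors is package private) of the wrap/unwrap contract documented in the scaladoc above: wrap maps a Catalyst value to either a boxed Java object or a Hadoop Writable depending on whether the inspector prefers writables, and unwrap(wrap(v, oi), oi) is expected to give back the original value either way.

package org.apache.spark.sql.hive

import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

object WrapUnwrapRoundTrip extends HiveInspectors {
  def main(args: Array[String]): Unit = {
    val javaIntOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector
    val writableIntOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector

    // The java inspector path boxes the Int to java.lang.Integer ...
    val viaJava = unwrap(wrap(42, javaIntOI), javaIntOI)
    // ... while the writable-preferring inspector path goes through hadoop.io.IntWritable.
    val viaWritable = unwrap(wrap(42, writableIntOI), writableIntOI)

    assert(viaJava == 42 && viaWritable == 42)
  }
}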
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
new file mode 100644
index 0000000000..bfe608a51a
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.sql.Date
+import java.util
+
+import org.apache.hadoop.hive.serde2.io.DoubleWritable
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+import org.scalatest.FunSuite
+
+import org.apache.hadoop.hive.ql.udf.UDAFPercentile
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StructObjectInspector, ObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
+import org.apache.hadoop.io.LongWritable
+
+import org.apache.spark.sql.catalyst.expressions.{Literal, Row}
+
+class HiveInspectorSuite extends FunSuite with HiveInspectors {
+ test("Test wrap SettableStructObjectInspector") {
+ val udaf = new UDAFPercentile.PercentileLongEvaluator()
+ udaf.init()
+
+ udaf.iterate(new LongWritable(1), 0.1)
+ udaf.iterate(new LongWritable(1), 0.1)
+
+ val state = udaf.terminatePartial()
+
+ val soi = ObjectInspectorFactory.getReflectionObjectInspector(
+ classOf[UDAFPercentile.State],
+ ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector]
+
+ val a = unwrap(state, soi).asInstanceOf[Row]
+ val b = wrap(a, soi).asInstanceOf[UDAFPercentile.State]
+
+ val sfCounts = soi.getStructFieldRef("counts")
+ val sfPercentiles = soi.getStructFieldRef("percentiles")
+
+ assert(2 === soi.getStructFieldData(b, sfCounts)
+ .asInstanceOf[util.Map[LongWritable, LongWritable]]
+ .get(new LongWritable(1L))
+ .get())
+ assert(0.1 === soi.getStructFieldData(b, sfPercentiles)
+ .asInstanceOf[util.ArrayList[DoubleWritable]]
+ .get(0)
+ .get())
+ }
+
+ val data =
+ Literal(true) ::
+ Literal(0.asInstanceOf[Byte]) ::
+ Literal(0.asInstanceOf[Short]) ::
+ Literal(0) ::
+ Literal(0.asInstanceOf[Long]) ::
+ Literal(0.asInstanceOf[Float]) ::
+ Literal(0.asInstanceOf[Double]) ::
+ Literal("0") ::
+ Literal(new Date(2014, 9, 23)) ::
+ Literal(Decimal(BigDecimal(123.123))) ::
+ Literal(new java.sql.Timestamp(123123)) ::
+ Literal(Array[Byte](1,2,3)) ::
+ Literal(Seq[Int](1,2,3), ArrayType(IntegerType)) ::
+ Literal(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) ::
+ Literal(Row(1,2.0d,3.0f),
+ StructType(StructField("c1", IntegerType) ::
+ StructField("c2", DoubleType) ::
+ StructField("c3", FloatType) :: Nil)) ::
+ Nil
+
+ val row = data.map(_.eval(null))
+ val dataTypes = data.map(_.dataType)
+
+ import scala.collection.JavaConversions._
+ def toWritableInspector(dataType: DataType): ObjectInspector = dataType match {
+ case ArrayType(tpe, _) =>
+ ObjectInspectorFactory.getStandardListObjectInspector(toWritableInspector(tpe))
+ case MapType(keyType, valueType, _) =>
+ ObjectInspectorFactory.getStandardMapObjectInspector(
+ toWritableInspector(keyType), toWritableInspector(valueType))
+ case StringType => PrimitiveObjectInspectorFactory.writableStringObjectInspector
+ case IntegerType => PrimitiveObjectInspectorFactory.writableIntObjectInspector
+ case DoubleType => PrimitiveObjectInspectorFactory.writableDoubleObjectInspector
+ case BooleanType => PrimitiveObjectInspectorFactory.writableBooleanObjectInspector
+ case LongType => PrimitiveObjectInspectorFactory.writableLongObjectInspector
+ case FloatType => PrimitiveObjectInspectorFactory.writableFloatObjectInspector
+ case ShortType => PrimitiveObjectInspectorFactory.writableShortObjectInspector
+ case ByteType => PrimitiveObjectInspectorFactory.writableByteObjectInspector
+ case NullType => PrimitiveObjectInspectorFactory.writableVoidObjectInspector
+ case BinaryType => PrimitiveObjectInspectorFactory.writableBinaryObjectInspector
+ case DateType => PrimitiveObjectInspectorFactory.writableDateObjectInspector
+ case TimestampType => PrimitiveObjectInspectorFactory.writableTimestampObjectInspector
+ case DecimalType() => PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector
+ case StructType(fields) =>
+ ObjectInspectorFactory.getStandardStructObjectInspector(
+ fields.map(f => f.name), fields.map(f => toWritableInspector(f.dataType)))
+ }
+
+ def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = {
+ dt1.zip(dt2).map {
+ case (dd1, dd2) =>
+ assert(dd1.getClass === dd2.getClass) // DecimalType doesn't have the default precision info
+ }
+ }
+
+ def checkValues(row1: Seq[Any], row2: Seq[Any]): Unit = {
+ row1.zip(row2).map {
+ case (r1, r2) => checkValues(r1, r2)
+ }
+ }
+
+ def checkValues(v1: Any, v2: Any): Unit = {
+ (v1, v2) match {
+ case (r1: Decimal, r2: Decimal) =>
+ // Ignore the Decimal precision
+ assert(r1.compare(r2) === 0)
+ case (r1: Array[Byte], r2: Array[Byte])
+ if r1 != null && r2 != null && r1.length == r2.length =>
+ r1.zip(r2).map { case (b1, b2) => assert(b1 === b2) }
+ case (r1: Date, r2: Date) => assert(r1.compareTo(r2) === 0)
+ case (r1, r2) => assert(r1 === r2)
+ }
+ }
+
+ test("oi => datatype => oi") {
+ val ois = dataTypes.map(toInspector)
+
+ checkDataType(ois.map(inspectorToDataType), dataTypes)
+ checkDataType(dataTypes.map(toWritableInspector).map(inspectorToDataType), dataTypes)
+ }
+
+ test("wrap / unwrap null, constant null and writables") {
+ val writableOIs = dataTypes.map(toWritableInspector)
+ val nullRow = data.map(d => null)
+
+ checkValues(nullRow, nullRow.zip(writableOIs).map {
+ case (d, oi) => unwrap(wrap(d, oi), oi)
+ })
+
+ // a struct cannot be a constant, so filter it out
+ val constantExprs = data.filter(!_.dataType.isInstanceOf[StructType])
+ val constantData = constantExprs.map(_.eval())
+ val constantNullData = constantData.map(_ => null)
+ val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType))
+ val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal(null, e.dataType)))
+
+ checkValues(constantData, constantData.zip(constantWritableOIs).map {
+ case (d, oi) => unwrap(wrap(d, oi), oi)
+ })
+
+ checkValues(constantNullData, constantData.zip(constantNullWritableOIs).map {
+ case (d, oi) => unwrap(wrap(d, oi), oi)
+ })
+
+ checkValues(constantNullData, constantNullData.zip(constantWritableOIs).map {
+ case (d, oi) => unwrap(wrap(d, oi), oi)
+ })
+ }
+
+ test("wrap / unwrap primitive writable object inspector") {
+ val writableOIs = dataTypes.map(toWritableInspector)
+
+ checkValues(row, row.zip(writableOIs).map {
+ case (data, oi) => unwrap(wrap(data, oi), oi)
+ })
+ }
+
+ test("wrap / unwrap primitive java object inspector") {
+ val ois = dataTypes.map(toInspector)
+
+ checkValues(row, row.zip(ois).map {
+ case (data, oi) => unwrap(wrap(data, oi), oi)
+ })
+ }
+
+ test("wrap / unwrap Struct Type") {
+ val dt = StructType(dataTypes.zipWithIndex.map {
+ case (t, idx) => StructField(s"c_$idx", t)
+ })
+
+ checkValues(row, unwrap(wrap(row, toInspector(dt)), toInspector(dt)).asInstanceOf[Row])
+ checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
+ }
+
+ test("wrap / unwrap Array Type") {
+ val dt = ArrayType(dataTypes(0))
+
+ val d = row(0) :: row(0) :: Nil
+ checkValues(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt)))
+ checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
+ checkValues(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt))))
+ checkValues(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt))))
+ }
+
+ test("wrap / unwrap Map Type") {
+ val dt = MapType(dataTypes(0), dataTypes(1))
+
+ val d = Map(row(0) -> row(1))
+ checkValues(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt)))
+ checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
+ checkValues(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt))))
+ checkValues(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt))))
+ }
+}
diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
index 67cc886575..2d01a85067 100644
--- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
+++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
@@ -35,6 +35,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector,
import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory}
import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils}
import org.apache.hadoop.hive.serde2.{io => hiveIo}
+import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.{io => hadoopIo}
import org.apache.hadoop.mapred.InputFormat
import org.apache.spark.sql.catalyst.types.decimal.Decimal
@@ -71,76 +72,114 @@ private[hive] object HiveShim {
def getStringWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.STRING,
- if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]))
+ getStringWritable(value))
def getIntWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.INT,
- if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]))
+ getIntWritable(value))
def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.DOUBLE,
- if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double]))
+ getDoubleWritable(value))
def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.BOOLEAN,
- if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]))
+ getBooleanWritable(value))
def getLongWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.LONG,
- if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]))
+ getLongWritable(value))
def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.FLOAT,
- if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float]))
+ getFloatWritable(value))
def getShortWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.SHORT,
- if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]))
+ getShortWritable(value))
def getByteWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.BYTE,
- if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]))
+ getByteWritable(value))
def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.BINARY,
- if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]))
+ getBinaryWritable(value))
def getDateWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.DATE,
- if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]))
+ getDateWritable(value))
def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.TIMESTAMP,
- if (value == null) {
- null
- } else {
- new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp])
- })
+ getTimestampWritable(value))
def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.DECIMAL,
- if (value == null) {
- null
- } else {
- new hiveIo.HiveDecimalWritable(
- HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying()))
- })
+ getDecimalWritable(value))
def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.VOID, null)
+ def getStringWritable(value: Any): hadoopIo.Text =
+ if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])
+
+ def getIntWritable(value: Any): hadoopIo.IntWritable =
+ if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])
+
+ def getDoubleWritable(value: Any): hiveIo.DoubleWritable =
+ if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double])
+
+ def getBooleanWritable(value: Any): hadoopIo.BooleanWritable =
+ if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean])
+
+ def getLongWritable(value: Any): hadoopIo.LongWritable =
+ if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])
+
+ def getFloatWritable(value: Any): hadoopIo.FloatWritable =
+ if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float])
+
+ def getShortWritable(value: Any): hiveIo.ShortWritable =
+ if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])
+
+ def getByteWritable(value: Any): hiveIo.ByteWritable =
+ if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])
+
+ def getBinaryWritable(value: Any): hadoopIo.BytesWritable =
+ if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]])
+
+ def getDateWritable(value: Any): hiveIo.DateWritable =
+ if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])
+
+ def getTimestampWritable(value: Any): hiveIo.TimestampWritable =
+ if (value == null) {
+ null
+ } else {
+ new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp])
+ }
+
+ def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable =
+ if (value == null) {
+ null
+ } else {
+ new hiveIo.HiveDecimalWritable(
+ HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying()))
+ }
+
+ def getPrimitiveNullWritable: NullWritable = NullWritable.get()
+
def createDriverResultsArray = new JArrayList[String]
def processResults(results: JArrayList[String]) = results
@@ -197,7 +236,11 @@ private[hive] object HiveShim {
}
def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
- Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+ if (hdoi.preferWritable()) {
+ Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue)
+ } else {
+ Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+ }
}
}
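
Editor's note: a brief usage sketch (illustrative, not part of the diff; the object name WritableHelpersExample is hypothetical, and it assumes the org.apache.spark.sql.hive package since HiveShim is package private) of the null-safe helpers added above and mirrored in the 0.13.1 shim below: each getXWritable factory returns null for a null input rather than throwing, which lets the constant-object-inspector factories and the writable branch of wrap delegate to them directly.

package org.apache.spark.sql.hive

import org.apache.hadoop.io.IntWritable

object WritableHelpersExample {
  def main(args: Array[String]): Unit = {
    val some: IntWritable = HiveShim.getIntWritable(42)    // IntWritable containing 42
    val none: IntWritable = HiveShim.getIntWritable(null)  // null, no exception thrown
    println(s"some = $some, none = $none")
  }
}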
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index 7c8cbf10c1..b78c75798e 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -22,6 +22,7 @@ import java.util.Properties
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.hive.common.StatsSetupConst
import org.apache.hadoop.hive.common.`type`.{HiveDecimal}
@@ -163,91 +164,123 @@ private[hive] object HiveShim {
new TableDesc(inputFormatClass, outputFormatClass, properties)
}
+
def getStringWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.stringTypeInfo,
- if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]))
+ TypeInfoFactory.stringTypeInfo, getStringWritable(value))
def getIntWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.intTypeInfo,
- if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]))
+ TypeInfoFactory.intTypeInfo, getIntWritable(value))
def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.doubleTypeInfo, if (value == null) {
- null
- } else {
- new hiveIo.DoubleWritable(value.asInstanceOf[Double])
- })
+ TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value))
def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.booleanTypeInfo, if (value == null) {
- null
- } else {
- new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean])
- })
+ TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value))
def getLongWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.longTypeInfo,
- if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]))
+ TypeInfoFactory.longTypeInfo, getLongWritable(value))
def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.floatTypeInfo, if (value == null) {
- null
- } else {
- new hadoopIo.FloatWritable(value.asInstanceOf[Float])
- })
+ TypeInfoFactory.floatTypeInfo, getFloatWritable(value))
def getShortWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.shortTypeInfo,
- if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]))
+ TypeInfoFactory.shortTypeInfo, getShortWritable(value))
def getByteWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.byteTypeInfo,
- if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]))
+ TypeInfoFactory.byteTypeInfo, getByteWritable(value))
def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.binaryTypeInfo, if (value == null) {
- null
- } else {
- new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]])
- })
+ TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value))
def getDateWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.dateTypeInfo,
- if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]))
+ TypeInfoFactory.dateTypeInfo, getDateWritable(value))
def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.timestampTypeInfo, if (value == null) {
- null
- } else {
- new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp])
- })
+ TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value))
def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
- TypeInfoFactory.decimalTypeInfo,
- if (value == null) {
- null
- } else {
- // TODO precise, scale?
- new hiveIo.HiveDecimalWritable(
- HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying()))
- })
+ TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value))
def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
TypeInfoFactory.voidTypeInfo, null)
+ def getStringWritable(value: Any): hadoopIo.Text =
+ if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])
+
+ def getIntWritable(value: Any): hadoopIo.IntWritable =
+ if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])
+
+ def getDoubleWritable(value: Any): hiveIo.DoubleWritable =
+ if (value == null) {
+ null
+ } else {
+ new hiveIo.DoubleWritable(value.asInstanceOf[Double])
+ }
+
+ def getBooleanWritable(value: Any): hadoopIo.BooleanWritable =
+ if (value == null) {
+ null
+ } else {
+ new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean])
+ }
+
+ def getLongWritable(value: Any): hadoopIo.LongWritable =
+ if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])
+
+ def getFloatWritable(value: Any): hadoopIo.FloatWritable =
+ if (value == null) {
+ null
+ } else {
+ new hadoopIo.FloatWritable(value.asInstanceOf[Float])
+ }
+
+ def getShortWritable(value: Any): hiveIo.ShortWritable =
+ if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])
+
+ def getByteWritable(value: Any): hiveIo.ByteWritable =
+ if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])
+
+ def getBinaryWritable(value: Any): hadoopIo.BytesWritable =
+ if (value == null) {
+ null
+ } else {
+ new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]])
+ }
+
+ def getDateWritable(value: Any): hiveIo.DateWritable =
+ if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])
+
+ def getTimestampWritable(value: Any): hiveIo.TimestampWritable =
+ if (value == null) {
+ null
+ } else {
+ new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp])
+ }
+
+ def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable =
+ if (value == null) {
+ null
+ } else {
+ // TODO precision, scale?
+ new hiveIo.HiveDecimalWritable(
+ HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying()))
+ }
+
+ def getPrimitiveNullWritable: NullWritable = NullWritable.get()
+
def createDriverResultsArray = new JArrayList[Object]
def processResults(results: JArrayList[Object]) = {
@@ -355,7 +388,12 @@ private[hive] object HiveShim {
}
def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
- Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale())
+ if (hdoi.preferWritable()) {
+ Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue,
+ hdoi.precision(), hdoi.scale())
+ } else {
+ Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale())
+ }
}
}