From ae9f128608f67cbee0a2fb24754783ee3b4f3098 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 18 Dec 2014 20:21:52 -0800 Subject: [SPARK-4573] [SQL] Add SettableStructObjectInspector support in "wrap" function Hive UDAF may create an customized object constructed by SettableStructObjectInspector, this is critical when integrate Hive UDAF with the refactor-ed UDAF interface. Performance issue in `wrap/unwrap` since more match cases added, will do it in another PR. Author: Cheng Hao Closes #3429 from chenghao-intel/settable_oi and squashes the following commits: 9f0aff3 [Cheng Hao] update code style issues as feedbacks 2b0561d [Cheng Hao] Add more scala doc f5a40e8 [Cheng Hao] add scala doc 2977e9b [Cheng Hao] remove the timezone setting for test suite 3ed284c [Cheng Hao] fix the date type comparison f1b6749 [Cheng Hao] Update the comment 932940d [Cheng Hao] Add more unit test 72e4332 [Cheng Hao] Add settable StructObjectInspector support --- .../org/apache/spark/sql/hive/HiveInspectors.scala | 346 +++++++++++++++++---- .../apache/spark/sql/hive/HiveInspectorSuite.scala | 220 +++++++++++++ .../scala/org/apache/spark/sql/hive/Shim12.scala | 87 ++++-- .../scala/org/apache/spark/sql/hive/Shim13.scala | 130 +++++--- 4 files changed, 659 insertions(+), 124 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala (limited to 'sql/hive') diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 0eeac8620f..06189341f8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -18,9 +18,7 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector._ import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} @@ -33,6 +31,145 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal /* Implicit conversions */ import scala.collection.JavaConversions._ +/** + * 1. The Underlying data type in catalyst and in Hive + * In catalyst: + * Primitive => + * java.lang.String + * int / scala.Int + * boolean / scala.Boolean + * float / scala.Float + * double / scala.Double + * long / scala.Long + * short / scala.Short + * byte / scala.Byte + * org.apache.spark.sql.catalyst.types.decimal.Decimal + * Array[Byte] + * java.sql.Date + * java.sql.Timestamp + * Complex Types => + * Map: scala.collection.immutable.Map + * List: scala.collection.immutable.Seq + * Struct: + * org.apache.spark.sql.catalyst.expression.Row + * Union: NOT SUPPORTED YET + * The Complex types plays as a container, which can hold arbitrary data types. 
+ * + * In Hive, the native data types are various, in UDF/UDAF/UDTF, and associated with + * Object Inspectors, in Hive expression evaluation framework, the underlying data are + * Primitive Type + * Java Boxed Primitives: + * org.apache.hadoop.hive.common.type.HiveVarchar + * java.lang.String + * java.lang.Integer + * java.lang.Boolean + * java.lang.Float + * java.lang.Double + * java.lang.Long + * java.lang.Short + * java.lang.Byte + * org.apache.hadoop.hive.common.`type`.HiveDecimal + * byte[] + * java.sql.Date + * java.sql.Timestamp + * Writables: + * org.apache.hadoop.hive.serde2.io.HiveVarcharWritable + * org.apache.hadoop.io.Text + * org.apache.hadoop.io.IntWritable + * org.apache.hadoop.hive.serde2.io.DoubleWritable + * org.apache.hadoop.io.BooleanWritable + * org.apache.hadoop.io.LongWritable + * org.apache.hadoop.io.FloatWritable + * org.apache.hadoop.hive.serde2.io.ShortWritable + * org.apache.hadoop.hive.serde2.io.ByteWritable + * org.apache.hadoop.io.BytesWritable + * org.apache.hadoop.hive.serde2.io.DateWritable + * org.apache.hadoop.hive.serde2.io.TimestampWritable + * org.apache.hadoop.hive.serde2.io.HiveDecimalWritable + * Complex Type + * List: Object[] / java.util.List + * Map: java.util.Map + * Struct: Object[] / java.util.List / java POJO + * Union: class StandardUnion { byte tag; Object object } + * + * NOTICE: HiveVarchar is not supported by catalyst, it will be simply considered as String type. + * + * + * 2. Hive ObjectInspector is a group of flexible APIs to inspect value in different data + * representation, and developers can extend those API as needed, so technically, + * object inspector supports arbitrary data type in java. + * + * Fortunately, only few built-in Hive Object Inspectors are used in generic udf/udaf/udtf + * evaluation. + * 1) Primitive Types (PrimitiveObjectInspector & its sub classes) + {{{ + public interface PrimitiveObjectInspector { + // Java Primitives (java.lang.Integer, java.lang.String etc.) + Object getPrimitiveWritableObject(Object o); + // Writables (hadoop.io.IntWritable, hadoop.io.Text etc.) + Object getPrimitiveJavaObject(Object o); + // ObjectInspector only inspect the `writable` always return true, we need to check it + // before invoking the methods above. + boolean preferWritable(); + ... + } + }}} + + * 2) Complex Types: + * ListObjectInspector: inspects java array or [[java.util.List]] + * MapObjectInspector: inspects [[java.util.Map]] + * Struct.StructObjectInspector: inspects java array, [[java.util.List]] and + * even a normal java object (POJO) + * UnionObjectInspector: (tag: Int, object data) (TODO: not supported by SparkSQL yet) + * + * 3) ConstantObjectInspector: + * Constant object inspector can be either primitive type or Complex type, and it bundles a + * constant value as its property, usually the value is created when the constant object inspector + * constructed. + * {{{ + public interface ConstantObjectInspector extends ObjectInspector { + Object getWritableConstantValue(); + ... 
+ } + }}} + * Hive provides 3 built-in constant object inspectors: + * Primitive Object Inspectors: + * WritableConstantStringObjectInspector + * WritableConstantHiveVarcharObjectInspector + * WritableConstantHiveDecimalObjectInspector + * WritableConstantTimestampObjectInspector + * WritableConstantIntObjectInspector + * WritableConstantDoubleObjectInspector + * WritableConstantBooleanObjectInspector + * WritableConstantLongObjectInspector + * WritableConstantFloatObjectInspector + * WritableConstantShortObjectInspector + * WritableConstantByteObjectInspector + * WritableConstantBinaryObjectInspector + * WritableConstantDateObjectInspector + * Map Object Inspector: + * StandardConstantMapObjectInspector + * List Object Inspector: + * StandardConstantListObjectInspector]] + * Struct Object Inspector: Hive doesn't provide the built-in constant object inspector for Struct + * Union Object Inspector: Hive doesn't provide the built-in constant object inspector for Union + * + * + * 3. This trait facilitates: + * Data Unwrapping: Hive Data => Catalyst Data (unwrap) + * Data Wrapping: Catalyst Data => Hive Data (wrap) + * Binding the Object Inspector for Catalyst Data (toInspector) + * Retrieving the Catalyst Data Type from Object Inspector (inspectorToDataType) + * + * + * 4. Future Improvement (TODO) + * This implementation is quite ugly and inefficient: + * a. Pattern matching in runtime + * b. Small objects creation in catalyst data => writable + * c. Unnecessary unwrap / wrap for nested UDF invoking: + * e.g. date_add(printf("%s-%s-%s", a,b,c), 3) + * We don't need to unwrap the data for printf and wrap it again and passes in data_add + */ private[hive] trait HiveInspectors { def javaClassToDataType(clz: Class[_]): DataType = clz match { @@ -87,10 +224,23 @@ private[hive] trait HiveInspectors { * @param oi the ObjectInspector associated with the Hive Type * @return convert the data into catalyst type * TODO return the function of (data => Any) instead for performance consideration + * + * Strictly follows the following order in unwrapping (constant OI has the higher priority): + * Constant Null object inspector => + * return null + * Constant object inspector => + * extract the value from constant object inspector + * Check whether the `data` is null => + * return null if true + * If object inspector prefers writable => + * extract writable from `data` and then get the catalyst type from the writable + * Extract the java object directly from the object inspector + * + * NOTICE: the complex data type requires recursive unwrapping. 
*/ def unwrap(data: Any, oi: ObjectInspector): Any = oi match { - case _ if data == null => null - case poi: VoidObjectInspector => null + case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null + case poi: WritableConstantStringObjectInspector => poi.getWritableConstantValue.toString case poi: WritableConstantHiveVarcharObjectInspector => poi.getWritableConstantValue.getHiveVarchar.getValue case poi: WritableConstantHiveDecimalObjectInspector => @@ -119,12 +269,44 @@ private[hive] trait HiveInspectors { System.arraycopy(writable.getBytes, 0, temp, 0, temp.length) temp case poi: WritableConstantDateObjectInspector => poi.getWritableConstantValue.get() - case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue - case hdoi: HiveDecimalObjectInspector => HiveShim.toCatalystDecimal(hdoi, data) - // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object - // if next timestamp is null, so Timestamp object is cloned - case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone() - case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) + case mi: StandardConstantMapObjectInspector => + // take the value from the map inspector object, rather than the input data + mi.getWritableConstantValue.map { case (k, v) => + (unwrap(k, mi.getMapKeyObjectInspector), + unwrap(v, mi.getMapValueObjectInspector)) + }.toMap + case li: StandardConstantListObjectInspector => + // take the value from the list inspector object, rather than the input data + li.getWritableConstantValue.map(unwrap(_, li.getListElementObjectInspector)).toSeq + // if the value is null, we don't care about the object inspector type + case _ if data == null => null + case poi: VoidObjectInspector => null // always be null for void object inspector + case pi: PrimitiveObjectInspector => pi match { + // We think HiveVarchar is also a String + case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => + hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue + case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue + case x: StringObjectInspector if x.preferWritable() => + x.getPrimitiveWritableObject(data).toString + case x: IntObjectInspector if x.preferWritable() => x.get(data) + case x: BooleanObjectInspector if x.preferWritable() => x.get(data) + case x: FloatObjectInspector if x.preferWritable() => x.get(data) + case x: DoubleObjectInspector if x.preferWritable() => x.get(data) + case x: LongObjectInspector if x.preferWritable() => x.get(data) + case x: ShortObjectInspector if x.preferWritable() => x.get(data) + case x: ByteObjectInspector if x.preferWritable() => x.get(data) + case x: HiveDecimalObjectInspector => HiveShim.toCatalystDecimal(x, data) + case x: BinaryObjectInspector if x.preferWritable() => + x.getPrimitiveWritableObject(data).copyBytes() + case x: DateObjectInspector if x.preferWritable() => + x.getPrimitiveWritableObject(data).get() + // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object + // if next timestamp is null, so Timestamp object is cloned + case x: TimestampObjectInspector if x.preferWritable() => + x.getPrimitiveWritableObject(data).getTimestamp.clone() + case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone() + case _ => pi.getPrimitiveJavaObject(data) + } case li: ListObjectInspector => Option(li.getList(data)) .map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq) @@ -132,10 +314,11 @@ 
private[hive] trait HiveInspectors { case mi: MapObjectInspector => Option(mi.getMap(data)).map( _.map { - case (k,v) => + case (k, v) => (unwrap(k, mi.getMapKeyObjectInspector), unwrap(v, mi.getMapValueObjectInspector)) }.toMap).orNull + // currently, hive doesn't provide the ConstantStructObjectInspector case si: StructObjectInspector => val allRefs = si.getAllStructFieldRefs new GenericRow( @@ -191,55 +374,89 @@ private[hive] trait HiveInspectors { * the ObjectInspector should also be consistent with those returned from * toInspector: DataType => ObjectInspector and * toInspector: Expression => ObjectInspector + * + * Strictly follows the following order in wrapping (constant OI has the higher priority): + * Constant object inspector => return the bundled value of Constant object inspector + * Check whether the `a` is null => return null if true + * If object inspector prefers writable object => return a Writable for the given data `a` + * Map the catalyst data to the boxed java primitive + * + * NOTICE: the complex data type requires recursive wrapping. */ - def wrap(a: Any, oi: ObjectInspector): AnyRef = if (a == null) { - null - } else { - oi match { - case x: ConstantObjectInspector => x.getWritableConstantValue - case x: PrimitiveObjectInspector => a match { - // TODO what if x.preferWritable() == true? reuse the writable? - case s: String => s: java.lang.String - case i: Int => i: java.lang.Integer - case b: Boolean => b: java.lang.Boolean - case f: Float => f: java.lang.Float - case d: Double => d: java.lang.Double - case l: Long => l: java.lang.Long - case l: Short => l: java.lang.Short - case l: Byte => l: java.lang.Byte - case b: BigDecimal => HiveShim.createDecimal(b.underlying()) - case d: Decimal => HiveShim.createDecimal(d.toBigDecimal.underlying()) - case b: Array[Byte] => b - case d: java.sql.Date => d - case t: java.sql.Timestamp => t + def wrap(a: Any, oi: ObjectInspector): AnyRef = oi match { + case x: ConstantObjectInspector => x.getWritableConstantValue + case _ if a == null => null + case x: PrimitiveObjectInspector => x match { + // TODO we don't support the HiveVarcharObjectInspector yet. 
+ case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a) + case _: StringObjectInspector => a.asInstanceOf[java.lang.String] + case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a) + case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] + case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a) + case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean] + case _: FloatObjectInspector if x.preferWritable() => HiveShim.getFloatWritable(a) + case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float] + case _: DoubleObjectInspector if x.preferWritable() => HiveShim.getDoubleWritable(a) + case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double] + case _: LongObjectInspector if x.preferWritable() => HiveShim.getLongWritable(a) + case _: LongObjectInspector => a.asInstanceOf[java.lang.Long] + case _: ShortObjectInspector if x.preferWritable() => HiveShim.getShortWritable(a) + case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short] + case _: ByteObjectInspector if x.preferWritable() => HiveShim.getByteWritable(a) + case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte] + case _: HiveDecimalObjectInspector if x.preferWritable() => + HiveShim.getDecimalWritable(a.asInstanceOf[Decimal]) + case _: HiveDecimalObjectInspector => + HiveShim.createDecimal(a.asInstanceOf[Decimal].toBigDecimal.underlying()) + case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a) + case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] + case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a) + case _: DateObjectInspector => a.asInstanceOf[java.sql.Date] + case _: TimestampObjectInspector if x.preferWritable() => HiveShim.getTimestampWritable(a) + case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp] + } + case x: SettableStructObjectInspector => + val fieldRefs = x.getAllStructFieldRefs + val row = a.asInstanceOf[Seq[_]] + // 1. create the pojo (most likely) object + val result = x.create() + var i = 0 + while (i < fieldRefs.length) { + // 2. set the property for the pojo + x.setStructFieldData( + result, + fieldRefs.get(i), + wrap(row(i), fieldRefs.get(i).getFieldObjectInspector)) + i += 1 } - case x: StructObjectInspector => - val fieldRefs = x.getAllStructFieldRefs - val row = a.asInstanceOf[Seq[_]] - val result = new java.util.ArrayList[AnyRef](fieldRefs.length) - var i = 0 - while (i < fieldRefs.length) { - result.add(wrap(row(i), fieldRefs.get(i).getFieldObjectInspector)) - i += 1 - } - result - case x: ListObjectInspector => - val list = new java.util.ArrayList[Object] - a.asInstanceOf[Seq[_]].foreach { - v => list.add(wrap(v, x.getListElementObjectInspector)) - } - list - case x: MapObjectInspector => - // Some UDFs seem to assume we pass in a HashMap. 
- val hashMap = new java.util.HashMap[AnyRef, AnyRef]() - hashMap.putAll(a.asInstanceOf[Map[_, _]].map { - case (k, v) => - wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector) - }) + result + case x: StructObjectInspector => + val fieldRefs = x.getAllStructFieldRefs + val row = a.asInstanceOf[Seq[_]] + val result = new java.util.ArrayList[AnyRef](fieldRefs.length) + var i = 0 + while (i < fieldRefs.length) { + result.add(wrap(row(i), fieldRefs.get(i).getFieldObjectInspector)) + i += 1 + } - hashMap - } + result + case x: ListObjectInspector => + val list = new java.util.ArrayList[Object] + a.asInstanceOf[Seq[_]].foreach { + v => list.add(wrap(v, x.getListElementObjectInspector)) + } + list + case x: MapObjectInspector => + // Some UDFs seem to assume we pass in a HashMap. + val hashMap = new java.util.HashMap[AnyRef, AnyRef]() + hashMap.putAll(a.asInstanceOf[Map[_, _]].map { + case (k, v) => + wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector) + }) + + hashMap } def wrap( @@ -254,6 +471,11 @@ private[hive] trait HiveInspectors { cache } + /** + * @param dataType Catalyst data type + * @return Hive java object inspector (recursively), not the Writable ObjectInspector + * We can easily map to the Hive built-in object inspector according to the data type. + */ def toInspector(dataType: DataType): ObjectInspector = dataType match { case ArrayType(tpe, _) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) @@ -272,12 +494,20 @@ private[hive] trait HiveInspectors { case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector + // TODO decimal precision? case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) } + /** + * Map the catalyst expression to ObjectInspector, however, + * if the expression is [[Literal]] or foldable, a constant writable object inspector returns; + * Otherwise, we always get the object inspector according to its data type(in catalyst) + * @param expr Catalyst expression to be mapped + * @return Hive java objectinspector (recursively). + */ def toInspector(expr: Expression): ObjectInspector = expr match { case Literal(value, StringType) => HiveShim.getStringWritableConstantObjectInspector(value) @@ -326,8 +556,12 @@ private[hive] trait HiveInspectors { }) ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map) } + // We will enumerate all of the possible constant expressions, throw exception if we missed case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") + // ideally, we don't test the foldable here(but in optimizer), however, some of the + // Hive UDF / UDAF requires its argument to be constant objectinspector, we do it eagerly. 
case _ if expr.foldable => toInspector(Literal(expr.eval(), expr.dataType)) + // For those non constant expression, map to object inspector according to its data type case _ => toInspector(expr.dataType) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala new file mode 100644 index 0000000000..bfe608a51a --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.sql.Date +import java.util + +import org.apache.hadoop.hive.serde2.io.DoubleWritable +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.scalatest.FunSuite + +import org.apache.hadoop.hive.ql.udf.UDAFPercentile +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StructObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions +import org.apache.hadoop.io.LongWritable + +import org.apache.spark.sql.catalyst.expressions.{Literal, Row} + +class HiveInspectorSuite extends FunSuite with HiveInspectors { + test("Test wrap SettableStructObjectInspector") { + val udaf = new UDAFPercentile.PercentileLongEvaluator() + udaf.init() + + udaf.iterate(new LongWritable(1), 0.1) + udaf.iterate(new LongWritable(1), 0.1) + + val state = udaf.terminatePartial() + + val soi = ObjectInspectorFactory.getReflectionObjectInspector( + classOf[UDAFPercentile.State], + ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector] + + val a = unwrap(state, soi).asInstanceOf[Row] + val b = wrap(a, soi).asInstanceOf[UDAFPercentile.State] + + val sfCounts = soi.getStructFieldRef("counts") + val sfPercentiles = soi.getStructFieldRef("percentiles") + + assert(2 === soi.getStructFieldData(b, sfCounts) + .asInstanceOf[util.Map[LongWritable, LongWritable]] + .get(new LongWritable(1L)) + .get()) + assert(0.1 === soi.getStructFieldData(b, sfPercentiles) + .asInstanceOf[util.ArrayList[DoubleWritable]] + .get(0) + .get()) + } + + val data = + Literal(true) :: + Literal(0.asInstanceOf[Byte]) :: + Literal(0.asInstanceOf[Short]) :: + Literal(0) :: + Literal(0.asInstanceOf[Long]) :: + Literal(0.asInstanceOf[Float]) :: + Literal(0.asInstanceOf[Double]) :: + Literal("0") :: + Literal(new Date(2014, 9, 23)) :: + Literal(Decimal(BigDecimal(123.123))) :: + Literal(new java.sql.Timestamp(123123)) :: + Literal(Array[Byte](1,2,3)) :: + Literal(Seq[Int](1,2,3), ArrayType(IntegerType)) :: + Literal(Map[Int, 
Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: + Literal(Row(1,2.0d,3.0f), + StructType(StructField("c1", IntegerType) :: + StructField("c2", DoubleType) :: + StructField("c3", FloatType) :: Nil)) :: + Nil + + val row = data.map(_.eval(null)) + val dataTypes = data.map(_.dataType) + + import scala.collection.JavaConversions._ + def toWritableInspector(dataType: DataType): ObjectInspector = dataType match { + case ArrayType(tpe, _) => + ObjectInspectorFactory.getStandardListObjectInspector(toWritableInspector(tpe)) + case MapType(keyType, valueType, _) => + ObjectInspectorFactory.getStandardMapObjectInspector( + toWritableInspector(keyType), toWritableInspector(valueType)) + case StringType => PrimitiveObjectInspectorFactory.writableStringObjectInspector + case IntegerType => PrimitiveObjectInspectorFactory.writableIntObjectInspector + case DoubleType => PrimitiveObjectInspectorFactory.writableDoubleObjectInspector + case BooleanType => PrimitiveObjectInspectorFactory.writableBooleanObjectInspector + case LongType => PrimitiveObjectInspectorFactory.writableLongObjectInspector + case FloatType => PrimitiveObjectInspectorFactory.writableFloatObjectInspector + case ShortType => PrimitiveObjectInspectorFactory.writableShortObjectInspector + case ByteType => PrimitiveObjectInspectorFactory.writableByteObjectInspector + case NullType => PrimitiveObjectInspectorFactory.writableVoidObjectInspector + case BinaryType => PrimitiveObjectInspectorFactory.writableBinaryObjectInspector + case DateType => PrimitiveObjectInspectorFactory.writableDateObjectInspector + case TimestampType => PrimitiveObjectInspectorFactory.writableTimestampObjectInspector + case DecimalType() => PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector + case StructType(fields) => + ObjectInspectorFactory.getStandardStructObjectInspector( + fields.map(f => f.name), fields.map(f => toWritableInspector(f.dataType))) + } + + def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = { + dt1.zip(dt2).map { + case (dd1, dd2) => + assert(dd1.getClass === dd2.getClass) // DecimalType doesn't has the default precision info + } + } + + def checkValues(row1: Seq[Any], row2: Seq[Any]): Unit = { + row1.zip(row2).map { + case (r1, r2) => checkValues(r1, r2) + } + } + + def checkValues(v1: Any, v2: Any): Unit = { + (v1, v2) match { + case (r1: Decimal, r2: Decimal) => + // Ignore the Decimal precision + assert(r1.compare(r2) === 0) + case (r1: Array[Byte], r2: Array[Byte]) + if r1 != null && r2 != null && r1.length == r2.length => + r1.zip(r2).map { case (b1, b2) => assert(b1 === b2) } + case (r1: Date, r2: Date) => assert(r1.compareTo(r2) === 0) + case (r1, r2) => assert(r1 === r2) + } + } + + test("oi => datatype => oi") { + val ois = dataTypes.map(toInspector) + + checkDataType(ois.map(inspectorToDataType), dataTypes) + checkDataType(dataTypes.map(toWritableInspector).map(inspectorToDataType), dataTypes) + } + + test("wrap / unwrap null, constant null and writables") { + val writableOIs = dataTypes.map(toWritableInspector) + val nullRow = data.map(d => null) + + checkValues(nullRow, nullRow.zip(writableOIs).map { + case (d, oi) => unwrap(wrap(d, oi), oi) + }) + + // struct couldn't be constant, sweep it out + val constantExprs = data.filter(!_.dataType.isInstanceOf[StructType]) + val constantData = constantExprs.map(_.eval()) + val constantNullData = constantData.map(_ => null) + val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType)) + val constantNullWritableOIs = constantExprs.map(e 
=> toInspector(Literal(null, e.dataType))) + + checkValues(constantData, constantData.zip(constantWritableOIs).map { + case (d, oi) => unwrap(wrap(d, oi), oi) + }) + + checkValues(constantNullData, constantData.zip(constantNullWritableOIs).map { + case (d, oi) => unwrap(wrap(d, oi), oi) + }) + + checkValues(constantNullData, constantNullData.zip(constantWritableOIs).map { + case (d, oi) => unwrap(wrap(d, oi), oi) + }) + } + + test("wrap / unwrap primitive writable object inspector") { + val writableOIs = dataTypes.map(toWritableInspector) + + checkValues(row, row.zip(writableOIs).map { + case (data, oi) => unwrap(wrap(data, oi), oi) + }) + } + + test("wrap / unwrap primitive java object inspector") { + val ois = dataTypes.map(toInspector) + + checkValues(row, row.zip(ois).map { + case (data, oi) => unwrap(wrap(data, oi), oi) + }) + } + + test("wrap / unwrap Struct Type") { + val dt = StructType(dataTypes.zipWithIndex.map { + case (t, idx) => StructField(s"c_$idx", t) + }) + + checkValues(row, unwrap(wrap(row, toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) + checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + } + + test("wrap / unwrap Array Type") { + val dt = ArrayType(dataTypes(0)) + + val d = row(0) :: row(0) :: Nil + checkValues(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) + checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + checkValues(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValues(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + } + + test("wrap / unwrap Map Type") { + val dt = MapType(dataTypes(0), dataTypes(1)) + + val d = Map(row(0) -> row(1)) + checkValues(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) + checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + checkValues(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValues(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + } +} diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 67cc886575..2d01a85067 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} import org.apache.hadoop.hive.serde2.{io => hiveIo} +import org.apache.hadoop.io.NullWritable import org.apache.hadoop.{io => hadoopIo} import org.apache.hadoop.mapred.InputFormat import org.apache.spark.sql.catalyst.types.decimal.Decimal @@ -71,76 +72,114 @@ private[hive] object HiveShim { def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.STRING, - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])) + getStringWritable(value)) def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.INT, - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])) + getIntWritable(value)) def 
getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.DOUBLE, - if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double])) + getDoubleWritable(value)) def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.BOOLEAN, - if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean])) + getBooleanWritable(value)) def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.LONG, - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])) + getLongWritable(value)) def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.FLOAT, - if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float])) + getFloatWritable(value)) def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.SHORT, - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])) + getShortWritable(value)) def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.BYTE, - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])) + getByteWritable(value)) def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.BINARY, - if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]])) + getBinaryWritable(value)) def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.DATE, - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])) + getDateWritable(value)) def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.TIMESTAMP, - if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - }) + getTimestampWritable(value)) def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.DECIMAL, - if (value == null) { - null - } else { - new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) - }) + getDecimalWritable(value)) def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( PrimitiveCategory.VOID, null) + def getStringWritable(value: Any): hadoopIo.Text = + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + + def getIntWritable(value: Any): hadoopIo.IntWritable = + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) + + def getDoubleWritable(value: Any): hiveIo.DoubleWritable = + if (value 
== null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double]) + + def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = + if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) + + def getLongWritable(value: Any): hadoopIo.LongWritable = + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) + + def getFloatWritable(value: Any): hadoopIo.FloatWritable = + if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float]) + + def getShortWritable(value: Any): hiveIo.ShortWritable = + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) + + def getByteWritable(value: Any): hiveIo.ByteWritable = + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) + + def getBinaryWritable(value: Any): hadoopIo.BytesWritable = + if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) + + def getDateWritable(value: Any): hiveIo.DateWritable = + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]) + + def getTimestampWritable(value: Any): hiveIo.TimestampWritable = + if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + } + + def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = + if (value == null) { + null + } else { + new hiveIo.HiveDecimalWritable( + HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) + } + + def getPrimitiveNullWritable: NullWritable = NullWritable.get() + def createDriverResultsArray = new JArrayList[String] def processResults(results: JArrayList[String]) = results @@ -197,7 +236,11 @@ private[hive] object HiveShim { } def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + if (hdoi.preferWritable()) { + Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue) + } else { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + } } } diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 7c8cbf10c1..b78c75798e 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -22,6 +22,7 @@ import java.util.Properties import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.NullWritable import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.common.`type`.{HiveDecimal} @@ -163,91 +164,123 @@ private[hive] object HiveShim { new TableDesc(inputFormatClass, outputFormatClass, properties) } + def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.stringTypeInfo, - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])) + TypeInfoFactory.stringTypeInfo, getStringWritable(value)) def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.intTypeInfo, - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])) + TypeInfoFactory.intTypeInfo, getIntWritable(value)) def 
getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.doubleTypeInfo, if (value == null) { - null - } else { - new hiveIo.DoubleWritable(value.asInstanceOf[Double]) - }) + TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value)) def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.booleanTypeInfo, if (value == null) { - null - } else { - new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) - }) + TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.longTypeInfo, - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])) + TypeInfoFactory.longTypeInfo, getLongWritable(value)) def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.floatTypeInfo, if (value == null) { - null - } else { - new hadoopIo.FloatWritable(value.asInstanceOf[Float]) - }) + TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.shortTypeInfo, - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])) + TypeInfoFactory.shortTypeInfo, getShortWritable(value)) def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.byteTypeInfo, - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])) + TypeInfoFactory.byteTypeInfo, getByteWritable(value)) def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.binaryTypeInfo, if (value == null) { - null - } else { - new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) - }) + TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.dateTypeInfo, - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])) + TypeInfoFactory.dateTypeInfo, getDateWritable(value)) def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.timestampTypeInfo, if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - }) + TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.decimalTypeInfo, - if (value == null) { - null - } else { - // TODO precise, scale? 
- new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) - }) + TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( TypeInfoFactory.voidTypeInfo, null) + def getStringWritable(value: Any): hadoopIo.Text = + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + + def getIntWritable(value: Any): hadoopIo.IntWritable = + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) + + def getDoubleWritable(value: Any): hiveIo.DoubleWritable = + if (value == null) { + null + } else { + new hiveIo.DoubleWritable(value.asInstanceOf[Double]) + } + + def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = + if (value == null) { + null + } else { + new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) + } + + def getLongWritable(value: Any): hadoopIo.LongWritable = + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) + + def getFloatWritable(value: Any): hadoopIo.FloatWritable = + if (value == null) { + null + } else { + new hadoopIo.FloatWritable(value.asInstanceOf[Float]) + } + + def getShortWritable(value: Any): hiveIo.ShortWritable = + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) + + def getByteWritable(value: Any): hiveIo.ByteWritable = + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) + + def getBinaryWritable(value: Any): hadoopIo.BytesWritable = + if (value == null) { + null + } else { + new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) + } + + def getDateWritable(value: Any): hiveIo.DateWritable = + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]) + + def getTimestampWritable(value: Any): hiveIo.TimestampWritable = + if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + } + + def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = + if (value == null) { + null + } else { + // TODO precise, scale? + new hiveIo.HiveDecimalWritable( + HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying())) + } + + def getPrimitiveNullWritable: NullWritable = NullWritable.get() + def createDriverResultsArray = new JArrayList[Object] def processResults(results: JArrayList[Object]) = { @@ -355,7 +388,12 @@ private[hive] object HiveShim { } def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + if (hdoi.preferWritable()) { + Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, + hdoi.precision(), hdoi.scale()) + } else { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } } } -- cgit v1.2.3
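
To see what the new SettableStructObjectInspector path in `wrap` buys us, the sketch below condenses the round trip exercised by HiveInspectorSuite above: a Hive UDAF's partial state is unwrapped into a catalyst Row and then wrapped back into the UDAF's own POJO through `SettableStructObjectInspector.create()` / `setStructFieldData()`. It is a minimal sketch, not part of the patch; it assumes the Spark 1.2-era sql/hive module and Hive jars on the classpath, and it must live in package org.apache.spark.sql.hive because HiveInspectors is private[hive]. The object name SettableOIRoundTrip is hypothetical.

package org.apache.spark.sql.hive

import org.apache.hadoop.hive.ql.udf.UDAFPercentile
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, StructObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
import org.apache.hadoop.io.LongWritable

import org.apache.spark.sql.catalyst.expressions.Row

// Hypothetical driver object, mirroring the "Test wrap SettableStructObjectInspector" test.
object SettableOIRoundTrip extends HiveInspectors {
  def main(args: Array[String]): Unit = {
    // Build some partial aggregation state with a real Hive UDAF evaluator.
    val udaf = new UDAFPercentile.PercentileLongEvaluator()
    udaf.init()
    udaf.iterate(new LongWritable(1), 0.1)
    val state = udaf.terminatePartial()

    // Reflection-based object inspector over the UDAF's state POJO; it is a
    // SettableStructObjectInspector, which `wrap` can now populate field by field.
    val soi = ObjectInspectorFactory.getReflectionObjectInspector(
      classOf[UDAFPercentile.State],
      ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector]

    val row = unwrap(state, soi).asInstanceOf[Row]                 // Hive POJO => catalyst Row
    val back = wrap(row, soi).asInstanceOf[UDAFPercentile.State]   // catalyst Row => Hive POJO

    // The reconstructed state should carry the same counts map as the original.
    println(soi.getStructFieldData(back, soi.getStructFieldRef("counts")))
  }
}

This is exactly the shape of integration the commit message refers to: the refactored UDAF interface hands catalyst rows to `wrap`, which now knows how to rebuild the settable struct the Hive evaluator expects.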
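
The unwrap/wrap doc comments above describe a strict ordering: constant OI first, then the null check, then the writable-preferring path, then plain Java boxing. A small sketch of that ordering for a primitive type, under the same classpath and packaging assumptions as above (object name WrapOrderingSketch is hypothetical):

package org.apache.spark.sql.hive

import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

object WrapOrderingSketch extends HiveInspectors {
  def main(args: Array[String]): Unit = {
    val javaOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector          // preferWritable() == false
    val writableOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector  // preferWritable() == true

    println(wrap(1, javaOI).getClass)       // java.lang.Integer (boxed Java primitive)
    println(wrap(1, writableOI).getClass)   // org.apache.hadoop.io.IntWritable, via HiveShim.getIntWritable
    // Either representation unwraps back to the catalyst Int.
    println(unwrap(wrap(1, writableOI), writableOI))  // 1
    // A null stays null for any non-constant inspector; the constant-OI case is checked first.
    println(wrap(null, writableOI))         // null
  }
}

The design point is that `wrap` no longer ignores `preferWritable()`: when the target inspector is writable-based, the shim's getXWritable helpers produce the Writable directly instead of handing Hive a boxed Java object it would have to convert again.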
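
The toInspector(expr) comment notes that a Literal (or any foldable expression) maps to a constant object inspector, and the constant cases in unwrap/wrap take the value bundled in the inspector rather than the incoming data. The "wrap / unwrap Array Type" test exercises this by wrapping a null against a literal-derived inspector and still getting the literal's value back. A minimal sketch of that behaviour, with the same assumptions as the previous sketches (ConstantOISketch is a hypothetical name):

package org.apache.spark.sql.hive

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.types.{ArrayType, IntegerType}

object ConstantOISketch extends HiveInspectors {
  def main(args: Array[String]): Unit = {
    val lit = Literal(Seq[Int](1, 2, 3), ArrayType(IntegerType))
    val constOI = toInspector(lit)   // a constant list inspector carrying the literal's value

    // Even when the wrapped data is null, the constant inspector yields its bundled value,
    // so the round trip still produces the sequence 1, 2, 3.
    println(unwrap(wrap(null, constOI), constOI))
  }
}

This eager constant handling exists because some Hive UDFs/UDAFs require their arguments to be constant object inspectors, as the comment in toInspector(expr) explains.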
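
Finally, the shim change to toCatalystDecimal follows the same preferWritable() pattern: when the inspector is writable-based, the value is read through the HiveDecimalWritable instead of getPrimitiveJavaObject. A short sketch under the same assumptions (DecimalUnwrapSketch is a hypothetical name):

package org.apache.spark.sql.hive

import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.spark.sql.catalyst.types.decimal.Decimal

object DecimalUnwrapSketch extends HiveInspectors {
  def main(args: Array[String]): Unit = {
    val writableOI = PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector
    val d = Decimal(BigDecimal(123.123))

    // wrap builds a HiveDecimalWritable (HiveShim.getDecimalWritable); unwrap goes through
    // HiveShim.toCatalystDecimal, which now checks preferWritable() before converting.
    println(unwrap(wrap(d, writableOI), writableOI))   // a catalyst Decimal equal to 123.123
  }
}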