diff options
author | Reynold Xin <rxin@databricks.com> | 2015-07-26 10:27:39 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-07-26 10:27:39 -0700 |
commit | 6c400b4f39be3fb5f473b8d2db11d239ea8ddf42 (patch) | |
tree | 4bd58339db246f3a686185e441eb1fb13ea0bc3a /sql/hive | |
parent | b79bf1df6238c087c3ec524344f1fc179719c5de (diff) | |
download | spark-6c400b4f39be3fb5f473b8d2db11d239ea8ddf42.tar.gz spark-6c400b4f39be3fb5f473b8d2db11d239ea8ddf42.tar.bz2 spark-6c400b4f39be3fb5f473b8d2db11d239ea8ddf42.zip |
[SPARK-9354][SQL] Remove InternalRow.get generic getter call in Hive integration code.
Replaced them with get(ordinal, datatype) so we can use UnsafeRow here.
I passed the data types throughout.
Author: Reynold Xin <rxin@databricks.com>
Closes #7669 from rxin/row-generic-getter-hive and squashes the following commits:
3467d8e [Reynold Xin] [SPARK-9354][SQL] Remove Internal.get generic getter call in Hive integration code.
Diffstat (limited to 'sql/hive')
3 files changed, 102 insertions, 68 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 16977ce30c..f467500259 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -46,7 +46,7 @@ import scala.collection.JavaConversions._ * long / scala.Long * short / scala.Short * byte / scala.Byte - * org.apache.spark.sql.types.Decimal + * [[org.apache.spark.sql.types.Decimal]] * Array[Byte] * java.sql.Date * java.sql.Timestamp @@ -54,7 +54,7 @@ import scala.collection.JavaConversions._ * Map: scala.collection.immutable.Map * List: scala.collection.immutable.Seq * Struct: - * org.apache.spark.sql.catalyst.expression.Row + * [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. * @@ -454,7 +454,7 @@ private[hive] trait HiveInspectors { * * NOTICE: the complex data type requires recursive wrapping. */ - def wrap(a: Any, oi: ObjectInspector): AnyRef = oi match { + def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = oi match { case x: ConstantObjectInspector => x.getWritableConstantValue case _ if a == null => null case x: PrimitiveObjectInspector => x match { @@ -488,43 +488,50 @@ private[hive] trait HiveInspectors { } case x: SettableStructObjectInspector => val fieldRefs = x.getAllStructFieldRefs + val structType = dataType.asInstanceOf[StructType] val row = a.asInstanceOf[InternalRow] // 1. create the pojo (most likely) object val result = x.create() var i = 0 while (i < fieldRefs.length) { // 2. set the property for the pojo + val tpe = structType(i).dataType x.setStructFieldData( result, fieldRefs.get(i), - wrap(row.get(i), fieldRefs.get(i).getFieldObjectInspector)) + wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 } result case x: StructObjectInspector => val fieldRefs = x.getAllStructFieldRefs + val structType = dataType.asInstanceOf[StructType] val row = a.asInstanceOf[InternalRow] val result = new java.util.ArrayList[AnyRef](fieldRefs.length) var i = 0 while (i < fieldRefs.length) { - result.add(wrap(row.get(i), fieldRefs.get(i).getFieldObjectInspector)) + val tpe = structType(i).dataType + result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 } result case x: ListObjectInspector => val list = new java.util.ArrayList[Object] + val tpe = dataType.asInstanceOf[ArrayType].elementType a.asInstanceOf[Seq[_]].foreach { - v => list.add(wrap(v, x.getListElementObjectInspector)) + v => list.add(wrap(v, x.getListElementObjectInspector, tpe)) } list case x: MapObjectInspector => + val keyType = dataType.asInstanceOf[MapType].keyType + val valueType = dataType.asInstanceOf[MapType].valueType // Some UDFs seem to assume we pass in a HashMap. val hashMap = new java.util.HashMap[AnyRef, AnyRef]() - hashMap.putAll(a.asInstanceOf[Map[_, _]].map { - case (k, v) => - wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector) + hashMap.putAll(a.asInstanceOf[Map[_, _]].map { case (k, v) => + wrap(k, x.getMapKeyObjectInspector, keyType) -> + wrap(v, x.getMapValueObjectInspector, valueType) }) hashMap @@ -533,22 +540,24 @@ private[hive] trait HiveInspectors { def wrap( row: InternalRow, inspectors: Seq[ObjectInspector], - cache: Array[AnyRef]): Array[AnyRef] = { + cache: Array[AnyRef], + dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 while (i < inspectors.length) { - cache(i) = wrap(row.get(i), inspectors(i)) + cache(i) = wrap(row.get(i, dataTypes(i)), inspectors(i), dataTypes(i)) i += 1 } cache } def wrap( - row: Seq[Any], - inspectors: Seq[ObjectInspector], - cache: Array[AnyRef]): Array[AnyRef] = { + row: Seq[Any], + inspectors: Seq[ObjectInspector], + cache: Array[AnyRef], + dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 while (i < inspectors.length) { - cache(i) = wrap(row(i), inspectors(i)) + cache(i) = wrap(row(i), inspectors(i), dataTypes(i)) i += 1 } cache @@ -625,7 +634,7 @@ private[hive] trait HiveInspectors { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) } else { val list = new java.util.ArrayList[Object]() - value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector))) + value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector, dt))) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } case Literal(value, MapType(keyType, valueType, _)) => @@ -636,7 +645,7 @@ private[hive] trait HiveInspectors { } else { val map = new java.util.HashMap[Object, Object]() value.asInstanceOf[Map[_, _]].foreach (entry => { - map.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI)) + map.put(wrap(entry._1, keyOI, keyType), wrap(entry._2, valueOI, valueType)) }) ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 3259b50acc..54bf6bd67f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -83,24 +83,22 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { - type UDFType = UDF - override def deterministic: Boolean = isUDFDeterministic override def nullable: Boolean = true @transient - lazy val function = funcWrapper.createFunction[UDFType]() + lazy val function = funcWrapper.createFunction[UDF]() @transient - protected lazy val method = + private lazy val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) @transient - protected lazy val arguments = children.map(toInspector).toArray + private lazy val arguments = children.map(toInspector).toArray @transient - protected lazy val isUDFDeterministic = { + private lazy val isUDFDeterministic = { val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) udfType != null && udfType.deterministic() } @@ -109,7 +107,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre // Create parameter converters @transient - protected lazy val conversionHelper = new ConversionHelper(method, arguments) + private lazy val conversionHelper = new ConversionHelper(method, arguments) @transient lazy val dataType = javaClassToDataType(method.getReturnType) @@ -119,14 +117,19 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre method.getGenericReturnType(), ObjectInspectorOptions.JAVA) @transient - protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) + private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) + + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray // TODO: Finish input output types. override def eval(input: InternalRow): Any = { - unwrap( - FunctionRegistry.invoke(method, function, conversionHelper - .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*), - returnInspector) + val inputs = wrap(children.map(c => c.eval(input)), arguments, cached, inputDataTypes) + val ret = FunctionRegistry.invoke( + method, + function, + conversionHelper.convertIfNecessary(inputs : _*): _*) + unwrap(ret, returnInspector) } override def toString: String = { @@ -135,47 +138,48 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre } // Adapter from Catalyst ExpressionResult to Hive DeferredObject -private[hive] class DeferredObjectAdapter(oi: ObjectInspector) +private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType) extends DeferredObject with HiveInspectors { + private var func: () => Any = _ def set(func: () => Any): Unit = { this.func = func } override def prepare(i: Int): Unit = {} - override def get(): AnyRef = wrap(func(), oi) + override def get(): AnyRef = wrap(func(), oi, dataType) } private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { - type UDFType = GenericUDF + + override def nullable: Boolean = true override def deterministic: Boolean = isUDFDeterministic - override def nullable: Boolean = true + override def foldable: Boolean = + isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] @transient - lazy val function = funcWrapper.createFunction[UDFType]() + lazy val function = funcWrapper.createFunction[GenericUDF]() @transient - protected lazy val argumentInspectors = children.map(toInspector) + private lazy val argumentInspectors = children.map(toInspector) @transient - protected lazy val returnInspector = { + private lazy val returnInspector = { function.initializeAndFoldConstants(argumentInspectors.toArray) } @transient - protected lazy val isUDFDeterministic = { + private lazy val isUDFDeterministic = { val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) udfType != null && udfType.deterministic() } - override def foldable: Boolean = - isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] - @transient - protected lazy val deferedObjects = - argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject] + private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) => + new DeferredObjectAdapter(inspect, child.dataType) + }.toArray[DeferredObject] lazy val dataType: DataType = inspectorToDataType(returnInspector) @@ -354,6 +358,9 @@ private[hive] case class HiveWindowFunction( // Output buffer. private var outputBuffer: Any = _ + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + override def init(): Unit = { evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) } @@ -368,8 +375,13 @@ private[hive] case class HiveWindowFunction( } override def prepareInputParameters(input: InternalRow): AnyRef = { - wrap(inputProjection(input), inputInspectors, new Array[AnyRef](children.length)) + wrap( + inputProjection(input), + inputInspectors, + new Array[AnyRef](children.length), + inputDataTypes) } + // Add input parameters for a single row. override def update(input: AnyRef): Unit = { evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]]) @@ -510,12 +522,15 @@ private[hive] case class HiveGenericUDTF( field => (inspectorToDataType(field.getFieldObjectInspector), true) } + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { outputInspector // Make sure initialized. val inputProjection = new InterpretedProjection(children) - function.process(wrap(inputProjection(input), inputInspectors, udtInput)) + function.process(wrap(inputProjection(input), inputInspectors, udtInput, inputDataTypes)) collector.collectRows() } @@ -584,9 +599,12 @@ private[hive] case class HiveUDAFFunction( @transient protected lazy val cached = new Array[AnyRef](exprs.length) + @transient + private lazy val inputDataTypes: Array[DataType] = exprs.map(_.dataType).toArray + def update(input: InternalRow): Unit = { val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, inspectors, cached)) + function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 8bb498a06f..0330013f53 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -48,7 +48,11 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector] val a = unwrap(state, soi).asInstanceOf[InternalRow] - val b = wrap(a, soi).asInstanceOf[UDAFPercentile.State] + + val dt = new StructType() + .add("counts", MapType(LongType, LongType)) + .add("percentiles", ArrayType(DoubleType)) + val b = wrap(a, soi, dt).asInstanceOf[UDAFPercentile.State] val sfCounts = soi.getStructFieldRef("counts") val sfPercentiles = soi.getStructFieldRef("percentiles") @@ -158,44 +162,45 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { val writableOIs = dataTypes.map(toWritableInspector) val nullRow = data.map(d => null) - checkValues(nullRow, nullRow.zip(writableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(nullRow, nullRow.zip(writableOIs).zip(dataTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) // struct couldn't be constant, sweep it out val constantExprs = data.filter(!_.dataType.isInstanceOf[StructType]) + val constantTypes = constantExprs.map(_.dataType) val constantData = constantExprs.map(_.eval()) val constantNullData = constantData.map(_ => null) val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType)) val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal.create(null, e.dataType))) - checkValues(constantData, constantData.zip(constantWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantData, constantData.zip(constantWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) - checkValues(constantNullData, constantData.zip(constantNullWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantNullData, constantData.zip(constantNullWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) - checkValues(constantNullData, constantNullData.zip(constantWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantNullData, constantNullData.zip(constantWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) } test("wrap / unwrap primitive writable object inspector") { val writableOIs = dataTypes.map(toWritableInspector) - checkValues(row, row.zip(writableOIs).map { - case (data, oi) => unwrap(wrap(data, oi), oi) + checkValues(row, row.zip(writableOIs).zip(dataTypes).map { + case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi) }) } test("wrap / unwrap primitive java object inspector") { val ois = dataTypes.map(toInspector) - checkValues(row, row.zip(ois).map { - case (data, oi) => unwrap(wrap(data, oi), oi) + checkValues(row, row.zip(ois).zip(dataTypes).map { + case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi) }) } @@ -205,31 +210,33 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { }) val inspector = toInspector(dt) checkValues(row, - unwrap(wrap(InternalRow.fromSeq(row), inspector), inspector).asInstanceOf[InternalRow]) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow]) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } test("wrap / unwrap Array Type") { val dt = ArrayType(dataTypes(0)) val d = row(0) :: row(0) :: Nil - checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) checkValue(d, - unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt)))) checkValue(d, - unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt), + toInspector(Literal.create(d, dt)))) } test("wrap / unwrap Map Type") { val dt = MapType(dataTypes(0), dataTypes(1)) val d = Map(row(0) -> row(1)) - checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) checkValue(d, - unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt)))) checkValue(d, - unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt), + toInspector(Literal.create(d, dt)))) } } |