author    Reynold Xin <rxin@databricks.com>  2015-07-26 10:27:39 -0700
committer Reynold Xin <rxin@databricks.com>  2015-07-26 10:27:39 -0700
commit    6c400b4f39be3fb5f473b8d2db11d239ea8ddf42
tree      4bd58339db246f3a686185e441eb1fb13ea0bc3a /sql
parent    b79bf1df6238c087c3ec524344f1fc179719c5de
[SPARK-9354][SQL] Remove InternalRow.get generic getter call in Hive integration code.
Replaced them with get(ordinal, dataType) so we can use UnsafeRow here. I passed the data types throughout.

Author: Reynold Xin <rxin@databricks.com>

Closes #7669 from rxin/row-generic-getter-hive and squashes the following commits:

3467d8e [Reynold Xin] [SPARK-9354][SQL] Remove InternalRow.get generic getter call in Hive integration code.
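For context, the core of the change is that the Hive wrappers no longer call the untyped InternalRow.get(i); they receive the Catalyst DataType for each field and call get(ordinal, dataType), which a binary row layout can answer through its typed accessors. Below is a minimal, self-contained sketch of that pattern; the Toy* types, wrapField, and WrapSketch are hypothetical stand-ins for illustration only, not Spark classes.

// Hypothetical stand-in for Catalyst's DataType hierarchy.
sealed trait ToyDataType
case object ToyIntType    extends ToyDataType
case object ToyStringType extends ToyDataType

// A row that only exposes typed accessors; the caller must say which type it
// expects at each ordinal, mirroring InternalRow.get(ordinal, dataType).
trait ToyRow {
  def getInt(ordinal: Int): Int
  def getString(ordinal: Int): String

  // Dispatch the untyped question to the right typed accessor.
  def get(ordinal: Int, dataType: ToyDataType): Any = dataType match {
    case ToyIntType    => getInt(ordinal)
    case ToyStringType => getString(ordinal)
  }
}

// A stand-in for a binary row such as UnsafeRow: values live in typed storage
// (two parallel arrays here for simplicity), so a generic get(ordinal) cannot
// be implemented without knowing the schema.
class ToyUnsafeRow(ints: Array[Int], strings: Array[String]) extends ToyRow {
  def getInt(ordinal: Int): Int       = ints(ordinal)
  def getString(ordinal: Int): String = strings(ordinal)
}

object WrapSketch {
  // Mirrors the shape of the patch: the wrapper receives the field's type and
  // threads it into row.get instead of relying on a generic getter.
  def wrapField(row: ToyRow, ordinal: Int, dataType: ToyDataType): AnyRef =
    row.get(ordinal, dataType).asInstanceOf[AnyRef]

  def main(args: Array[String]): Unit = {
    val row = new ToyUnsafeRow(Array(42, 0), Array("", "hello"))
    println(wrapField(row, 0, ToyIntType))    // 42
    println(wrapField(row, 1, ToyStringType)) // hello
  }
}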
Diffstat (limited to 'sql')
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala     | 43
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala           | 74
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala | 53
3 files changed, 102 insertions, 68 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 16977ce30c..f467500259 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -46,7 +46,7 @@ import scala.collection.JavaConversions._
* long / scala.Long
* short / scala.Short
* byte / scala.Byte
- * org.apache.spark.sql.types.Decimal
+ * [[org.apache.spark.sql.types.Decimal]]
* Array[Byte]
* java.sql.Date
* java.sql.Timestamp
@@ -54,7 +54,7 @@ import scala.collection.JavaConversions._
* Map: scala.collection.immutable.Map
* List: scala.collection.immutable.Seq
* Struct:
- * org.apache.spark.sql.catalyst.expression.Row
+ * [[org.apache.spark.sql.catalyst.InternalRow]]
* Union: NOT SUPPORTED YET
* The Complex types plays as a container, which can hold arbitrary data types.
*
@@ -454,7 +454,7 @@ private[hive] trait HiveInspectors {
*
* NOTICE: the complex data type requires recursive wrapping.
*/
- def wrap(a: Any, oi: ObjectInspector): AnyRef = oi match {
+ def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = oi match {
case x: ConstantObjectInspector => x.getWritableConstantValue
case _ if a == null => null
case x: PrimitiveObjectInspector => x match {
@@ -488,43 +488,50 @@ private[hive] trait HiveInspectors {
}
case x: SettableStructObjectInspector =>
val fieldRefs = x.getAllStructFieldRefs
+ val structType = dataType.asInstanceOf[StructType]
val row = a.asInstanceOf[InternalRow]
// 1. create the pojo (most likely) object
val result = x.create()
var i = 0
while (i < fieldRefs.length) {
// 2. set the property for the pojo
+ val tpe = structType(i).dataType
x.setStructFieldData(
result,
fieldRefs.get(i),
- wrap(row.get(i), fieldRefs.get(i).getFieldObjectInspector))
+ wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe))
i += 1
}
result
case x: StructObjectInspector =>
val fieldRefs = x.getAllStructFieldRefs
+ val structType = dataType.asInstanceOf[StructType]
val row = a.asInstanceOf[InternalRow]
val result = new java.util.ArrayList[AnyRef](fieldRefs.length)
var i = 0
while (i < fieldRefs.length) {
- result.add(wrap(row.get(i), fieldRefs.get(i).getFieldObjectInspector))
+ val tpe = structType(i).dataType
+ result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe))
i += 1
}
result
case x: ListObjectInspector =>
val list = new java.util.ArrayList[Object]
+ val tpe = dataType.asInstanceOf[ArrayType].elementType
a.asInstanceOf[Seq[_]].foreach {
- v => list.add(wrap(v, x.getListElementObjectInspector))
+ v => list.add(wrap(v, x.getListElementObjectInspector, tpe))
}
list
case x: MapObjectInspector =>
+ val keyType = dataType.asInstanceOf[MapType].keyType
+ val valueType = dataType.asInstanceOf[MapType].valueType
// Some UDFs seem to assume we pass in a HashMap.
val hashMap = new java.util.HashMap[AnyRef, AnyRef]()
- hashMap.putAll(a.asInstanceOf[Map[_, _]].map {
- case (k, v) =>
- wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector)
+ hashMap.putAll(a.asInstanceOf[Map[_, _]].map { case (k, v) =>
+ wrap(k, x.getMapKeyObjectInspector, keyType) ->
+ wrap(v, x.getMapValueObjectInspector, valueType)
})
hashMap
@@ -533,22 +540,24 @@ private[hive] trait HiveInspectors {
def wrap(
row: InternalRow,
inspectors: Seq[ObjectInspector],
- cache: Array[AnyRef]): Array[AnyRef] = {
+ cache: Array[AnyRef],
+ dataTypes: Array[DataType]): Array[AnyRef] = {
var i = 0
while (i < inspectors.length) {
- cache(i) = wrap(row.get(i), inspectors(i))
+ cache(i) = wrap(row.get(i, dataTypes(i)), inspectors(i), dataTypes(i))
i += 1
}
cache
}
def wrap(
- row: Seq[Any],
- inspectors: Seq[ObjectInspector],
- cache: Array[AnyRef]): Array[AnyRef] = {
+ row: Seq[Any],
+ inspectors: Seq[ObjectInspector],
+ cache: Array[AnyRef],
+ dataTypes: Array[DataType]): Array[AnyRef] = {
var i = 0
while (i < inspectors.length) {
- cache(i) = wrap(row(i), inspectors(i))
+ cache(i) = wrap(row(i), inspectors(i), dataTypes(i))
i += 1
}
cache
@@ -625,7 +634,7 @@ private[hive] trait HiveInspectors {
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null)
} else {
val list = new java.util.ArrayList[Object]()
- value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector)))
+ value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector, dt)))
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list)
}
case Literal(value, MapType(keyType, valueType, _)) =>
@@ -636,7 +645,7 @@ private[hive] trait HiveInspectors {
} else {
val map = new java.util.HashMap[Object, Object]()
value.asInstanceOf[Map[_, _]].foreach (entry => {
- map.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI))
+ map.put(wrap(entry._1, keyOI, keyType), wrap(entry._2, valueOI, valueType))
})
ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map)
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 3259b50acc..54bf6bd67f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -83,24 +83,22 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with CodegenFallback with Logging {
- type UDFType = UDF
-
override def deterministic: Boolean = isUDFDeterministic
override def nullable: Boolean = true
@transient
- lazy val function = funcWrapper.createFunction[UDFType]()
+ lazy val function = funcWrapper.createFunction[UDF]()
@transient
- protected lazy val method =
+ private lazy val method =
function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
@transient
- protected lazy val arguments = children.map(toInspector).toArray
+ private lazy val arguments = children.map(toInspector).toArray
@transient
- protected lazy val isUDFDeterministic = {
+ private lazy val isUDFDeterministic = {
val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
udfType != null && udfType.deterministic()
}
@@ -109,7 +107,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre
// Create parameter converters
@transient
- protected lazy val conversionHelper = new ConversionHelper(method, arguments)
+ private lazy val conversionHelper = new ConversionHelper(method, arguments)
@transient
lazy val dataType = javaClassToDataType(method.getReturnType)
@@ -119,14 +117,19 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre
method.getGenericReturnType(), ObjectInspectorOptions.JAVA)
@transient
- protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
+ private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
+
+ @transient
+ private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
// TODO: Finish input output types.
override def eval(input: InternalRow): Any = {
- unwrap(
- FunctionRegistry.invoke(method, function, conversionHelper
- .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*),
- returnInspector)
+ val inputs = wrap(children.map(c => c.eval(input)), arguments, cached, inputDataTypes)
+ val ret = FunctionRegistry.invoke(
+ method,
+ function,
+ conversionHelper.convertIfNecessary(inputs : _*): _*)
+ unwrap(ret, returnInspector)
}
override def toString: String = {
@@ -135,47 +138,48 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre
}
// Adapter from Catalyst ExpressionResult to Hive DeferredObject
-private[hive] class DeferredObjectAdapter(oi: ObjectInspector)
+private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType)
extends DeferredObject with HiveInspectors {
+
private var func: () => Any = _
def set(func: () => Any): Unit = {
this.func = func
}
override def prepare(i: Int): Unit = {}
- override def get(): AnyRef = wrap(func(), oi)
+ override def get(): AnyRef = wrap(func(), oi, dataType)
}
private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with CodegenFallback with Logging {
- type UDFType = GenericUDF
+
+ override def nullable: Boolean = true
override def deterministic: Boolean = isUDFDeterministic
- override def nullable: Boolean = true
+ override def foldable: Boolean =
+ isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]
@transient
- lazy val function = funcWrapper.createFunction[UDFType]()
+ lazy val function = funcWrapper.createFunction[GenericUDF]()
@transient
- protected lazy val argumentInspectors = children.map(toInspector)
+ private lazy val argumentInspectors = children.map(toInspector)
@transient
- protected lazy val returnInspector = {
+ private lazy val returnInspector = {
function.initializeAndFoldConstants(argumentInspectors.toArray)
}
@transient
- protected lazy val isUDFDeterministic = {
+ private lazy val isUDFDeterministic = {
val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
udfType != null && udfType.deterministic()
}
- override def foldable: Boolean =
- isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]
-
@transient
- protected lazy val deferedObjects =
- argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject]
+ private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
+ new DeferredObjectAdapter(inspect, child.dataType)
+ }.toArray[DeferredObject]
lazy val dataType: DataType = inspectorToDataType(returnInspector)
@@ -354,6 +358,9 @@ private[hive] case class HiveWindowFunction(
// Output buffer.
private var outputBuffer: Any = _
+ @transient
+ private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+
override def init(): Unit = {
evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors)
}
@@ -368,8 +375,13 @@ private[hive] case class HiveWindowFunction(
}
override def prepareInputParameters(input: InternalRow): AnyRef = {
- wrap(inputProjection(input), inputInspectors, new Array[AnyRef](children.length))
+ wrap(
+ inputProjection(input),
+ inputInspectors,
+ new Array[AnyRef](children.length),
+ inputDataTypes)
}
+
// Add input parameters for a single row.
override def update(input: AnyRef): Unit = {
evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]])
@@ -510,12 +522,15 @@ private[hive] case class HiveGenericUDTF(
field => (inspectorToDataType(field.getFieldObjectInspector), true)
}
+ @transient
+ private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
outputInspector // Make sure initialized.
val inputProjection = new InterpretedProjection(children)
- function.process(wrap(inputProjection(input), inputInspectors, udtInput))
+ function.process(wrap(inputProjection(input), inputInspectors, udtInput, inputDataTypes))
collector.collectRows()
}
@@ -584,9 +599,12 @@ private[hive] case class HiveUDAFFunction(
@transient
protected lazy val cached = new Array[AnyRef](exprs.length)
+ @transient
+ private lazy val inputDataTypes: Array[DataType] = exprs.map(_.dataType).toArray
+
def update(input: InternalRow): Unit = {
val inputs = inputProjection(input)
- function.iterate(buffer, wrap(inputs, inspectors, cached))
+ function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes))
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 8bb498a06f..0330013f53 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -48,7 +48,11 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector]
val a = unwrap(state, soi).asInstanceOf[InternalRow]
- val b = wrap(a, soi).asInstanceOf[UDAFPercentile.State]
+
+ val dt = new StructType()
+ .add("counts", MapType(LongType, LongType))
+ .add("percentiles", ArrayType(DoubleType))
+ val b = wrap(a, soi, dt).asInstanceOf[UDAFPercentile.State]
val sfCounts = soi.getStructFieldRef("counts")
val sfPercentiles = soi.getStructFieldRef("percentiles")
@@ -158,44 +162,45 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
val writableOIs = dataTypes.map(toWritableInspector)
val nullRow = data.map(d => null)
- checkValues(nullRow, nullRow.zip(writableOIs).map {
- case (d, oi) => unwrap(wrap(d, oi), oi)
+ checkValues(nullRow, nullRow.zip(writableOIs).zip(dataTypes).map {
+ case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi)
})
// struct couldn't be constant, sweep it out
val constantExprs = data.filter(!_.dataType.isInstanceOf[StructType])
+ val constantTypes = constantExprs.map(_.dataType)
val constantData = constantExprs.map(_.eval())
val constantNullData = constantData.map(_ => null)
val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType))
val constantNullWritableOIs =
constantExprs.map(e => toInspector(Literal.create(null, e.dataType)))
- checkValues(constantData, constantData.zip(constantWritableOIs).map {
- case (d, oi) => unwrap(wrap(d, oi), oi)
+ checkValues(constantData, constantData.zip(constantWritableOIs).zip(constantTypes).map {
+ case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi)
})
- checkValues(constantNullData, constantData.zip(constantNullWritableOIs).map {
- case (d, oi) => unwrap(wrap(d, oi), oi)
+ checkValues(constantNullData, constantData.zip(constantNullWritableOIs).zip(constantTypes).map {
+ case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi)
})
- checkValues(constantNullData, constantNullData.zip(constantWritableOIs).map {
- case (d, oi) => unwrap(wrap(d, oi), oi)
+ checkValues(constantNullData, constantNullData.zip(constantWritableOIs).zip(constantTypes).map {
+ case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi)
})
}
test("wrap / unwrap primitive writable object inspector") {
val writableOIs = dataTypes.map(toWritableInspector)
- checkValues(row, row.zip(writableOIs).map {
- case (data, oi) => unwrap(wrap(data, oi), oi)
+ checkValues(row, row.zip(writableOIs).zip(dataTypes).map {
+ case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi)
})
}
test("wrap / unwrap primitive java object inspector") {
val ois = dataTypes.map(toInspector)
- checkValues(row, row.zip(ois).map {
- case (data, oi) => unwrap(wrap(data, oi), oi)
+ checkValues(row, row.zip(ois).zip(dataTypes).map {
+ case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi)
})
}
@@ -205,31 +210,33 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
})
val inspector = toInspector(dt)
checkValues(row,
- unwrap(wrap(InternalRow.fromSeq(row), inspector), inspector).asInstanceOf[InternalRow])
- checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
+ unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow])
+ checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
}
test("wrap / unwrap Array Type") {
val dt = ArrayType(dataTypes(0))
val d = row(0) :: row(0) :: Nil
- checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt)))
- checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
+ checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt)))
+ checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
checkValue(d,
- unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt))))
+ unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt))))
checkValue(d,
- unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt))))
+ unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt),
+ toInspector(Literal.create(d, dt))))
}
test("wrap / unwrap Map Type") {
val dt = MapType(dataTypes(0), dataTypes(1))
val d = Map(row(0) -> row(1))
- checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt)))
- checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
+ checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt)))
+ checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
checkValue(d,
- unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt))))
+ unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt))))
checkValue(d,
- unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt))))
+ unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt),
+ toInspector(Literal.create(d, dt))))
}
}