author     Cheng Hao <hao.cheng@intel.com>          2014-10-28 19:11:57 -0700
committer  Michael Armbrust <michael@databricks.com> 2014-10-28 19:11:57 -0700
commit     b5e79bf889700159d490cdac1f6322dff424b1d9 (patch)
tree       c5befc6a89689ec7f4c70f0cee73a19d45819578 /sql
parent     1536d70331e9a4f5b5ea9dabfd72592ca1fc8e35 (diff)
[SPARK-3904] [SQL] add constant objectinspector support for udfs
In HQL, we convert all of the data types into normal `ObjectInspector`s for UDFs. In most cases this works, but some UDFs actually require their children's `ObjectInspector`s to be `ConstantObjectInspector`s, which otherwise causes an exception, e.g.:

    select named_struct("x", "str") from src limit 1;

I updated the method `wrap` by adding one more `ObjectInspector` parameter (to describe what it is expected to wrap to, for example java.lang.Integer or IntWritable), and likewise the `unwrap` method by providing the input `ObjectInspector`.

Author: Cheng Hao <hao.cheng@intel.com>

Closes #2762 from chenghao-intel/udf_coi and squashes the following commits:

bcacfd7 [Cheng Hao] Shim for both Hive 0.12 & 0.13.1
2416e5d [Cheng Hao] revert to hive 0.12
5793c01 [Cheng Hao] add space before while
4e56e1b [Cheng Hao] style issue
683d3fd [Cheng Hao] Add golden files
fe591e4 [Cheng Hao] update HiveGenericUdf to set the ObjectInspector while constructing the DeferredObject
f6740fe [Cheng Hao] Support Constant ObjectInspector for Map & List
8814c3a [Cheng Hao] Pass ConstantObjectInspector (when necessary) for UDF initialization
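[Editor's illustration, not part of the patch] A minimal sketch of the distinction the fix relies on, written against the Hive 0.13.1 serde2 API (the 0.12 shim differs only in the factory overload, see below): a plain ObjectInspector describes a type only, while a ConstantObjectInspector also carries the literal value that UDFs such as named_struct read during initialize().

    import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector}
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory
    import org.apache.hadoop.io.Text

    object ConstantOiSketch {
      def main(args: Array[String]): Unit = {
        // Type information only: what the old code handed to every UDF.
        val plain: ObjectInspector =
          PrimitiveObjectInspectorFactory.writableStringObjectInspector

        // Type information plus the value itself: what constant-hungry
        // UDFs need for their literal arguments.
        val constant: ObjectInspector =
          PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
            TypeInfoFactory.stringTypeInfo, new Text("x"))

        println(plain.isInstanceOf[ConstantObjectInspector])    // false
        println(constant.isInstanceOf[ConstantObjectInspector]) // true
      }
    }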
Diffstat (limited to 'sql')
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala |   8
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 185
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala |   2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala |  62
-rw-r--r--  sql/hive/src/test/resources/golden/constant array-0-761ef205b10ac4a10122c8b4ce10ada |   1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_named_struct-0-8f0ea83364b78634fbb3752c5a5c725 |   1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_named_struct-1-380c9638cc6ea8ea42f187bf0cedf350 |   1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_named_struct-2-22a79ac608b1249306f82f4bdc669b17 |   0
-rw-r--r--  sql/hive/src/test/resources/golden/udf_named_struct-3-d7e4a555934307155784904ff9df188b |   1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_sort_array-0-e86d559aeb84a4cc017a103182c22bfb |   0
-rw-r--r--  sql/hive/src/test/resources/golden/udf_sort_array-1-976cd8b6b50a2748bbc768aa5e11cf82 |   1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_sort_array-10-9e047718e5fea6ea79124f1e899f1c13 |   1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_sort_array-2-c429ec85a6da60ebd4bc6f0f266e8b93 |   4
-rw-r--r--  sql/hive/src/test/resources/golden/udf_sort_array-3-55c4cdaf8438b06675d60848d68f35de |   0
-rw-r--r--  sql/hive/src/test/resources/golden/udf_struct-0-f41043b7d9f14fa5e998c90454c7bdb1 |   1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_struct-1-8ccdb20153debdab789ea8ad0228e2eb |   1
-rw-r--r--  sql/hive/src/test/resources/golden/udf_struct-2-4a62774a6de7571c8d2bcb77da63f8f3 |   0
-rw-r--r--  sql/hive/src/test/resources/golden/udf_struct-3-abffdaacb0c7076ab538fbeec072daa2 |   1
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala |   8
-rw-r--r--  sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala |  57
-rw-r--r--  sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala |  64
21 files changed, 307 insertions(+), 92 deletions(-)
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 26d9ca05c8..1a3c24be42 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -233,7 +233,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// Sort with Limit clause causes failure.
"ctas",
- "ctas_hadoop20"
+ "ctas_hadoop20",
+
+ // timestamp in array, the output format of Hive contains double quotes, while
+ // Spark SQL doesn't
+ "udf_sort_array"
) ++ HiveShim.compatibilityBlackList
/**
@@ -861,6 +865,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_minute",
"udf_modulo",
"udf_month",
+ "udf_named_struct",
"udf_negative",
"udf_not",
"udf_notequal",
@@ -894,6 +899,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_stddev_pop",
"udf_stddev_samp",
"udf_string",
+ "udf_struct",
"udf_substring",
"udf_subtract",
"udf_sum",
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index fad7373a2f..c6103a124d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -18,7 +18,9 @@
package org.apache.spark.sql.hive
import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory
import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector._
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}
@@ -78,44 +80,13 @@ private[hive] trait HiveInspectors {
case c: Class[_] if c == classOf[java.lang.Object] => NullType
}
- /** Converts hive types to native catalyst types. */
- def unwrap(a: Any): Any = a match {
- case null => null
- case i: hadoopIo.IntWritable => i.get
- case t: hadoopIo.Text => t.toString
- case l: hadoopIo.LongWritable => l.get
- case d: hadoopIo.DoubleWritable => d.get
- case d: hiveIo.DoubleWritable => d.get
- case s: hiveIo.ShortWritable => s.get
- case b: hadoopIo.BooleanWritable => b.get
- case b: hiveIo.ByteWritable => b.get
- case b: hadoopIo.FloatWritable => b.get
- case b: hadoopIo.BytesWritable => {
- val bytes = new Array[Byte](b.getLength)
- System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength)
- bytes
- }
- case d: hiveIo.DateWritable => d.get
- case t: hiveIo.TimestampWritable => t.getTimestamp
- case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue())
- case list: java.util.List[_] => list.map(unwrap)
- case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap
- case array: Array[_] => array.map(unwrap).toSeq
- case p: java.lang.Short => p
- case p: java.lang.Long => p
- case p: java.lang.Float => p
- case p: java.lang.Integer => p
- case p: java.lang.Double => p
- case p: java.lang.Byte => p
- case p: java.lang.Boolean => p
- case str: String => str
- case p: java.math.BigDecimal => p
- case p: Array[Byte] => p
- case p: java.sql.Date => p
- case p: java.sql.Timestamp => p
- }
-
- def unwrapData(data: Any, oi: ObjectInspector): Any = oi match {
+ /**
+ * Converts hive types to native catalyst types.
+ * @param data the data in Hive type
+ * @param oi the ObjectInspector associated with the Hive Type
+ * @return convert the data into catalyst type
+ */
+ def unwrap(data: Any, oi: ObjectInspector): Any = oi match {
case hvoi: HiveVarcharObjectInspector =>
if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
case hdoi: HiveDecimalObjectInspector =>
@@ -123,43 +94,89 @@ private[hive] trait HiveInspectors {
case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
case li: ListObjectInspector =>
Option(li.getList(data))
- .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq)
+ .map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq)
.orNull
case mi: MapObjectInspector =>
Option(mi.getMap(data)).map(
_.map {
case (k,v) =>
- (unwrapData(k, mi.getMapKeyObjectInspector),
- unwrapData(v, mi.getMapValueObjectInspector))
+ (unwrap(k, mi.getMapKeyObjectInspector),
+ unwrap(v, mi.getMapValueObjectInspector))
}.toMap).orNull
case si: StructObjectInspector =>
val allRefs = si.getAllStructFieldRefs
new GenericRow(
allRefs.map(r =>
- unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray)
+ unwrap(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray)
}
- /** Converts native catalyst types to the types expected by Hive */
- def wrap(a: Any): AnyRef = a match {
- case s: String => s: java.lang.String
- case i: Int => i: java.lang.Integer
- case b: Boolean => b: java.lang.Boolean
- case f: Float => f: java.lang.Float
- case d: Double => d: java.lang.Double
- case l: Long => l: java.lang.Long
- case l: Short => l: java.lang.Short
- case l: Byte => l: java.lang.Byte
- case b: BigDecimal => HiveShim.createDecimal(b.underlying())
- case b: Array[Byte] => b
- case d: java.sql.Date => d
- case t: java.sql.Timestamp => t
- case s: Seq[_] => seqAsJavaList(s.map(wrap))
- case m: Map[_,_] =>
- // Some UDFs seem to assume we pass in a HashMap.
- val hashMap = new java.util.HashMap[AnyRef, AnyRef]()
- hashMap.putAll(m.map { case (k, v) => wrap(k) -> wrap(v) })
- hashMap
- case null => null
+ /**
+ * Converts native catalyst types to the types expected by Hive
+ * @param a the value to be wrapped
+ * @param oi This ObjectInspector associated with the value returned by this function, and
+ * the ObjectInspector should also be consistent with those returned from
+ * toInspector: DataType => ObjectInspector and
+ * toInspector: Expression => ObjectInspector
+ */
+ def wrap(a: Any, oi: ObjectInspector): AnyRef = if (a == null) {
+ null
+ } else {
+ oi match {
+ case x: ConstantObjectInspector => x.getWritableConstantValue
+ case x: PrimitiveObjectInspector => a match {
+ // TODO what if x.preferWritable() == true? reuse the writable?
+ case s: String => s: java.lang.String
+ case i: Int => i: java.lang.Integer
+ case b: Boolean => b: java.lang.Boolean
+ case f: Float => f: java.lang.Float
+ case d: Double => d: java.lang.Double
+ case l: Long => l: java.lang.Long
+ case l: Short => l: java.lang.Short
+ case l: Byte => l: java.lang.Byte
+ case b: BigDecimal => HiveShim.createDecimal(b.underlying())
+ case b: Array[Byte] => b
+ case d: java.sql.Date => d
+ case t: java.sql.Timestamp => t
+ }
+ case x: StructObjectInspector =>
+ val fieldRefs = x.getAllStructFieldRefs
+ val row = a.asInstanceOf[Seq[_]]
+ val result = new java.util.ArrayList[AnyRef](fieldRefs.length)
+ var i = 0
+ while (i < fieldRefs.length) {
+ result.add(wrap(row(i), fieldRefs.get(i).getFieldObjectInspector))
+ i += 1
+ }
+
+ result
+ case x: ListObjectInspector =>
+ val list = new java.util.ArrayList[Object]
+ a.asInstanceOf[Seq[_]].foreach {
+ v => list.add(wrap(v, x.getListElementObjectInspector))
+ }
+ list
+ case x: MapObjectInspector =>
+ // Some UDFs seem to assume we pass in a HashMap.
+ val hashMap = new java.util.HashMap[AnyRef, AnyRef]()
+ hashMap.putAll(a.asInstanceOf[Map[_, _]].map {
+ case (k, v) =>
+ wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector)
+ })
+
+ hashMap
+ }
+ }
+
+ def wrap(
+ row: Seq[Any],
+ inspectors: Seq[ObjectInspector],
+ cache: Array[AnyRef]): Array[AnyRef] = {
+ var i = 0
+ while (i < inspectors.length) {
+ cache(i) = wrap(row(i), inspectors(i))
+ i += 1
+ }
+ cache
}
def toInspector(dataType: DataType): ObjectInspector = dataType match {
@@ -186,6 +203,48 @@ private[hive] trait HiveInspectors {
fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
}
+ def toInspector(expr: Expression): ObjectInspector = expr match {
+ case Literal(value: String, StringType) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(value: Int, IntegerType) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(value: Double, DoubleType) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(value: Boolean, BooleanType) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(value: Long, LongType) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(value: Float, FloatType) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(value: Short, ShortType) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(value: Byte, ByteType) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(value: Array[Byte], BinaryType) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(value: java.sql.Date, DateType) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(value: java.sql.Timestamp, TimestampType) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(value: BigDecimal, DecimalType) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(_, NullType) =>
+ HiveShim.getPrimitiveNullWritableConstantObjectInspector
+ case Literal(value: Seq[_], ArrayType(dt, _)) =>
+ val listObjectInspector = toInspector(dt)
+ val list = new java.util.ArrayList[Object]()
+ value.foreach(v => list.add(wrap(v, listObjectInspector)))
+ ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list)
+ case Literal(map: Map[_, _], MapType(keyType, valueType, _)) =>
+ val value = new java.util.HashMap[Object, Object]()
+ val keyOI = toInspector(keyType)
+ val valueOI = toInspector(valueType)
+ map.foreach (entry => value.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI)))
+ ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, value)
+ case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].")
+ case _ => toInspector(expr.dataType)
+ }
+
def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
case s: StructObjectInspector =>
StructType(s.getAllStructFieldRefs.map(f => {
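[Editor's illustration, not part of the patch] A hedged sketch of what the new toInspector(expr) path produces for an array literal, using only factory calls that appear in the hunk above:

    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

    object ConstantListOiSketch {
      def main(args: Array[String]): Unit = {
        // Mirrors the Literal(value: Seq[_], ArrayType(dt, _)) case,
        // specialised to an array of strings.
        val elemOi = PrimitiveObjectInspectorFactory.javaStringObjectInspector
        val list = new java.util.ArrayList[Object]()
        Seq("enterprise databases", "hadoop map-reduce").foreach(list.add)

        val constListOi =
          ObjectInspectorFactory.getStandardConstantListObjectInspector(elemOi, list)

        // The inspector carries its value, so wrap() can short-circuit on the
        // ConstantObjectInspector case and a UDF can read it in initialize().
        println(constListOi.getWritableConstantValue) // [enterprise databases, hadoop map-reduce]
      }
    }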
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index 9ff7ab5a12..e49f0957d1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -291,7 +291,7 @@ private[hive] object HadoopTableReader extends HiveInspectors {
case oi: DoubleObjectInspector =>
(value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value))
case oi =>
- (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapData(value, oi)
+ (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi)
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 683c820dec..aff4ddce92 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -21,7 +21,9 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
import scala.collection.mutable.ArrayBuffer
-import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory
import org.apache.hadoop.hive.ql.exec.{UDF, UDAF}
import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
@@ -97,7 +99,7 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[
function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
@transient
- protected lazy val arguments = children.map(c => toInspector(c.dataType)).toArray
+ protected lazy val arguments = children.map(toInspector).toArray
@transient
protected lazy val isUDFDeterministic = {
@@ -116,12 +118,19 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[
@transient
lazy val dataType = javaClassToDataType(method.getReturnType)
+ @transient
+ lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector(
+ method.getGenericReturnType(), ObjectInspectorOptions.JAVA)
+
+ @transient
+ protected lazy val cached = new Array[AnyRef](children.length)
+
// TODO: Finish input output types.
override def eval(input: Row): Any = {
- val evaluatedChildren = children.map(c => wrap(c.eval(input)))
-
- unwrap(FunctionRegistry.invoke(method, function, conversionHelper
- .convertIfNecessary(evaluatedChildren: _*): _*))
+ unwrap(
+ FunctionRegistry.invoke(method, function, conversionHelper
+ .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*),
+ returnInspector)
}
}
@@ -133,7 +142,7 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq
type UDFType = GenericUDF
@transient
- protected lazy val argumentInspectors = children.map(_.dataType).map(toInspector)
+ protected lazy val argumentInspectors = children.map(toInspector)
@transient
protected lazy val returnInspector = function.initialize(argumentInspectors.toArray)
@@ -148,18 +157,18 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq
isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && n.foldable)
}
- protected lazy val deferedObjects = Array.fill[DeferredObject](children.length)({
- new DeferredObjectAdapter
- })
+ @transient
+ protected lazy val deferedObjects =
+ argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject]
// Adapter from Catalyst ExpressionResult to Hive DeferredObject
- class DeferredObjectAdapter extends DeferredObject {
+ class DeferredObjectAdapter(oi: ObjectInspector) extends DeferredObject {
private var func: () => Any = _
def set(func: () => Any) {
this.func = func
}
override def prepare(i: Int) = {}
- override def get(): AnyRef = wrap(func())
+ override def get(): AnyRef = wrap(func(), oi)
}
lazy val dataType: DataType = inspectorToDataType(returnInspector)
@@ -169,10 +178,13 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq
var i = 0
while (i < children.length) {
val idx = i
- deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set(() => {children(idx).eval(input)})
+ deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set(
+ () => {
+ children(idx).eval(input)
+ })
i += 1
}
- unwrap(function.evaluate(deferedObjects))
+ unwrap(function.evaluate(deferedObjects), returnInspector)
}
}
@@ -260,12 +272,14 @@ private[hive] case class HiveGenericUdtf(
protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
@transient
- protected lazy val outputInspectors = {
- val structInspector = function.initialize(inputInspectors.toArray)
- structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector)
- }
+ protected lazy val outputInspector = function.initialize(inputInspectors.toArray)
- protected lazy val outputDataTypes = outputInspectors.map(inspectorToDataType)
+ @transient
+ protected lazy val udtInput = new Array[AnyRef](children.length)
+
+ protected lazy val outputDataTypes = outputInspector.getAllStructFieldRefs.map {
+ field => inspectorToDataType(field.getFieldObjectInspector)
+ }
override protected def makeOutput() = {
// Use column names when given, otherwise c_1, c_2, ... c_n.
@@ -283,14 +297,12 @@ private[hive] case class HiveGenericUdtf(
}
override def eval(input: Row): TraversableOnce[Row] = {
- outputInspectors // Make sure initialized.
+ outputInspector // Make sure initialized.
val inputProjection = new InterpretedProjection(children)
val collector = new UDTFCollector
function.setCollector(collector)
-
- val udtInput = inputProjection(input).map(wrap).toArray
- function.process(udtInput)
+ function.process(wrap(inputProjection(input), inputInspectors, udtInput))
collector.collectRows()
}
@@ -301,7 +313,7 @@ private[hive] case class HiveGenericUdtf(
// We need to clone the input here because implementations of
// GenericUDTF reuse the same object. Luckily they are always an array, so
// it is easy to clone.
- collected += new GenericRow(input.asInstanceOf[Array[_]].map(unwrap))
+ collected += unwrap(input, outputInspector).asInstanceOf[Row]
}
def collectRows() = {
@@ -342,7 +354,7 @@ private[hive] case class HiveUdafFunction(
private val buffer =
function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer]
- override def eval(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector)
+ override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector)
@transient
val inputProjection = new InterpretedProjection(exprs)
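[Editor's note, not part of the patch] The three-argument wrap overload added in HiveInspectors exists so each operator refills one preallocated array instead of building a fresh collection per row; that is what the `cached` and `udtInput` arrays above are for. A minimal sketch of the calling pattern, with wrapOne as a hypothetical stand-in for the real single-value wrap:

    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

    object CachedWrapSketch {
      // Hypothetical stand-in for HiveInspectors.wrap(Any, ObjectInspector).
      def wrapOne(v: Any, oi: ObjectInspector): AnyRef = v.asInstanceOf[AnyRef]

      def main(args: Array[String]): Unit = {
        val inspectors = Seq[ObjectInspector](
          PrimitiveObjectInspectorFactory.javaIntObjectInspector,
          PrimitiveObjectInspectorFactory.javaStringObjectInspector)
        val cache = new Array[AnyRef](inspectors.length) // allocated once

        for (row <- Seq(Seq(1, "a"), Seq(2, "b"))) {
          var i = 0
          while (i < inspectors.length) { // refilled for every row
            cache(i) = wrapOne(row(i), inspectors(i))
            i += 1
          }
          println(cache.mkString(", ")) // consumed before the next row overwrites it
        }
      }
    }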
diff --git a/sql/hive/src/test/resources/golden/constant array-0-761ef205b10ac4a10122c8b4ce10ada b/sql/hive/src/test/resources/golden/constant array-0-761ef205b10ac4a10122c8b4ce10ada
new file mode 100644
index 0000000000..94f18d0986
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/constant array-0-761ef205b10ac4a10122c8b4ce10ada
@@ -0,0 +1 @@
+["enterprise databases","hadoop distributed file system","hadoop map-reduce"]
diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-0-8f0ea83364b78634fbb3752c5a5c725 b/sql/hive/src/test/resources/golden/udf_named_struct-0-8f0ea83364b78634fbb3752c5a5c725
new file mode 100644
index 0000000000..9bff96e7fa
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_named_struct-0-8f0ea83364b78634fbb3752c5a5c725
@@ -0,0 +1 @@
+named_struct(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values
diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-1-380c9638cc6ea8ea42f187bf0cedf350 b/sql/hive/src/test/resources/golden/udf_named_struct-1-380c9638cc6ea8ea42f187bf0cedf350
new file mode 100644
index 0000000000..9bff96e7fa
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_named_struct-1-380c9638cc6ea8ea42f187bf0cedf350
@@ -0,0 +1 @@
+named_struct(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values
diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-2-22a79ac608b1249306f82f4bdc669b17 b/sql/hive/src/test/resources/golden/udf_named_struct-2-22a79ac608b1249306f82f4bdc669b17
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_named_struct-2-22a79ac608b1249306f82f4bdc669b17
diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-3-d7e4a555934307155784904ff9df188b b/sql/hive/src/test/resources/golden/udf_named_struct-3-d7e4a555934307155784904ff9df188b
new file mode 100644
index 0000000000..de25f51b5b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_named_struct-3-d7e4a555934307155784904ff9df188b
@@ -0,0 +1 @@
+{"foo":1,"bar":2} 1
diff --git a/sql/hive/src/test/resources/golden/udf_sort_array-0-e86d559aeb84a4cc017a103182c22bfb b/sql/hive/src/test/resources/golden/udf_sort_array-0-e86d559aeb84a4cc017a103182c22bfb
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_sort_array-0-e86d559aeb84a4cc017a103182c22bfb
diff --git a/sql/hive/src/test/resources/golden/udf_sort_array-1-976cd8b6b50a2748bbc768aa5e11cf82 b/sql/hive/src/test/resources/golden/udf_sort_array-1-976cd8b6b50a2748bbc768aa5e11cf82
new file mode 100644
index 0000000000..d514df4191
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_sort_array-1-976cd8b6b50a2748bbc768aa5e11cf82
@@ -0,0 +1 @@
+sort_array(array(obj1, obj2,...)) - Sorts the input array in ascending order according to the natural ordering of the array elements.
diff --git a/sql/hive/src/test/resources/golden/udf_sort_array-10-9e047718e5fea6ea79124f1e899f1c13 b/sql/hive/src/test/resources/golden/udf_sort_array-10-9e047718e5fea6ea79124f1e899f1c13
new file mode 100644
index 0000000000..9d33cd51fe
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_sort_array-10-9e047718e5fea6ea79124f1e899f1c13
@@ -0,0 +1 @@
+[1,2,3,4,5] [1,2,7,8,9] [4,8,16,32,64] [1,100,246,357,1000] [false,true] [1.414,1.618,2.718,3.141] [1.41421,1.61803,2.71828,3.14159] ["","aramis","athos","portos"] ["1970-01-05 13:51:04.042","1970-01-07 00:54:54.442","1970-01-16 12:50:35.242"]
diff --git a/sql/hive/src/test/resources/golden/udf_sort_array-2-c429ec85a6da60ebd4bc6f0f266e8b93 b/sql/hive/src/test/resources/golden/udf_sort_array-2-c429ec85a6da60ebd4bc6f0f266e8b93
new file mode 100644
index 0000000000..43e36513de
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_sort_array-2-c429ec85a6da60ebd4bc6f0f266e8b93
@@ -0,0 +1,4 @@
+sort_array(array(obj1, obj2,...)) - Sorts the input array in ascending order according to the natural ordering of the array elements.
+Example:
+ > SELECT sort_array(array('b', 'd', 'c', 'a')) FROM src LIMIT 1;
+ 'a', 'b', 'c', 'd'
diff --git a/sql/hive/src/test/resources/golden/udf_sort_array-3-55c4cdaf8438b06675d60848d68f35de b/sql/hive/src/test/resources/golden/udf_sort_array-3-55c4cdaf8438b06675d60848d68f35de
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_sort_array-3-55c4cdaf8438b06675d60848d68f35de
diff --git a/sql/hive/src/test/resources/golden/udf_struct-0-f41043b7d9f14fa5e998c90454c7bdb1 b/sql/hive/src/test/resources/golden/udf_struct-0-f41043b7d9f14fa5e998c90454c7bdb1
new file mode 100644
index 0000000000..062cb1bc68
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_struct-0-f41043b7d9f14fa5e998c90454c7bdb1
@@ -0,0 +1 @@
+struct(col1, col2, col3, ...) - Creates a struct with the given field values
diff --git a/sql/hive/src/test/resources/golden/udf_struct-1-8ccdb20153debdab789ea8ad0228e2eb b/sql/hive/src/test/resources/golden/udf_struct-1-8ccdb20153debdab789ea8ad0228e2eb
new file mode 100644
index 0000000000..062cb1bc68
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_struct-1-8ccdb20153debdab789ea8ad0228e2eb
@@ -0,0 +1 @@
+struct(col1, col2, col3, ...) - Creates a struct with the given field values
diff --git a/sql/hive/src/test/resources/golden/udf_struct-2-4a62774a6de7571c8d2bcb77da63f8f3 b/sql/hive/src/test/resources/golden/udf_struct-2-4a62774a6de7571c8d2bcb77da63f8f3
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_struct-2-4a62774a6de7571c8d2bcb77da63f8f3
diff --git a/sql/hive/src/test/resources/golden/udf_struct-3-abffdaacb0c7076ab538fbeec072daa2 b/sql/hive/src/test/resources/golden/udf_struct-3-abffdaacb0c7076ab538fbeec072daa2
new file mode 100644
index 0000000000..ff1a28fa47
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_struct-3-abffdaacb0c7076ab538fbeec072daa2
@@ -0,0 +1 @@
+{"col1":1} {"col1":1,"col2":"a"} 1 a
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 322a25bb20..ffe1f0b90f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -34,6 +34,14 @@ case class TestData(a: Int, b: String)
* A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
*/
class HiveQuerySuite extends HiveComparisonTest {
+ createQueryTest("constant array",
+ """
+ |SELECT sort_array(
+ | sort_array(
+ | array("hadoop distributed file system",
+ | "enterprise databases", "hadoop map-reduce")))
+ |FROM src LIMIT 1;
+ """.stripMargin)
createQueryTest("count distinct 0 values",
"""
diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
index 8cb81db8a9..afc252ac27 100644
--- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
+++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
@@ -29,7 +29,11 @@ import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.stats.StatsSetupConst
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils}
+import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}
import org.apache.hadoop.mapred.InputFormat
import scala.collection.JavaConversions._
@@ -50,6 +54,59 @@ private[hive] object HiveShim {
new TableDesc(serdeClass, inputFormatClass, outputFormatClass, properties)
}
+ def getPrimitiveWritableConstantObjectInspector(value: String): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ PrimitiveCategory.STRING, new hadoopIo.Text(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Int): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ PrimitiveCategory.INT, new hadoopIo.IntWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Double): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ PrimitiveCategory.DOUBLE, new hiveIo.DoubleWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Boolean): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ PrimitiveCategory.BOOLEAN, new hadoopIo.BooleanWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Long): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ PrimitiveCategory.LONG, new hadoopIo.LongWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Float): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ PrimitiveCategory.FLOAT, new hadoopIo.FloatWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Short): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ PrimitiveCategory.SHORT, new hiveIo.ShortWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Byte): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ PrimitiveCategory.BYTE, new hiveIo.ByteWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Array[Byte]): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ PrimitiveCategory.BINARY, new hadoopIo.BytesWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: java.sql.Date): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ PrimitiveCategory.DATE, new hiveIo.DateWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: java.sql.Timestamp): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ PrimitiveCategory.TIMESTAMP, new hiveIo.TimestampWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: BigDecimal): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ PrimitiveCategory.DECIMAL,
+ new hiveIo.HiveDecimalWritable(HiveShim.createDecimal(value.underlying())))
+
+ def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ PrimitiveCategory.VOID, null)
+
def createDriverResultsArray = new JArrayList[String]
def processResults(results: JArrayList[String]) = results
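[Editor's illustration, not part of the patch] The 0.12 shim keys the constant inspectors by PrimitiveCategory, while the 0.13.1 shim below keys them by TypeInfo; the call sites in HiveInspectors stay identical. A small sketch, assuming a Hive 0.12 classpath:

    import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector
    import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
    import org.apache.hadoop.io.IntWritable

    object Shim12ConstantSketch {
      def main(args: Array[String]): Unit = {
        val intOi = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
          PrimitiveCategory.INT, new IntWritable(1))
        println(intOi.getTypeName) // int
        intOi match {
          case c: ConstantObjectInspector => println(c.getWritableConstantValue) // 1
          case _ => // not reached: the factory returns a constant inspector
        }
      }
    }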
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index b9a742cc6e..42cd65b251 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -21,6 +21,7 @@ import java.util.{ArrayList => JArrayList}
import java.util.Properties
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.hive.common.StatsSetupConst
import org.apache.hadoop.hive.common.`type`.{HiveDecimal}
import org.apache.hadoop.hive.conf.HiveConf
@@ -28,10 +29,16 @@ import org.apache.hadoop.hive.ql.Context
import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition}
import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory
import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer}
-import org.apache.hadoop.mapred.InputFormat
-import org.apache.spark.Logging
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
+import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils}
+import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}
+import org.apache.spark.Logging
+
import scala.collection.JavaConversions._
import scala.language.implicitConversions
@@ -54,6 +61,59 @@ private[hive] object HiveShim {
new TableDesc(inputFormatClass, outputFormatClass, properties)
}
+ def getPrimitiveWritableConstantObjectInspector(value: String): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.stringTypeInfo, new hadoopIo.Text(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Int): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.intTypeInfo, new hadoopIo.IntWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Double): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.doubleTypeInfo, new hiveIo.DoubleWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Boolean): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.booleanTypeInfo, new hadoopIo.BooleanWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Long): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.longTypeInfo, new hadoopIo.LongWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Float): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.floatTypeInfo, new hadoopIo.FloatWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Short): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.shortTypeInfo, new hiveIo.ShortWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Byte): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.byteTypeInfo, new hiveIo.ByteWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: Array[Byte]): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.binaryTypeInfo, new hadoopIo.BytesWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: java.sql.Date): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.dateTypeInfo, new hiveIo.DateWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: java.sql.Timestamp): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.timestampTypeInfo, new hiveIo.TimestampWritable(value))
+
+ def getPrimitiveWritableConstantObjectInspector(value: BigDecimal): ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.decimalTypeInfo,
+ new hiveIo.HiveDecimalWritable(HiveShim.createDecimal(value.underlying())))
+
+ def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector =
+ PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+ TypeInfoFactory.voidTypeInfo, null)
+
def createDriverResultsArray = new JArrayList[Object]
def processResults(results: JArrayList[Object]) = {