author     Michael Armbrust <michael@databricks.com>    2014-07-23 16:26:55 -0700
committer  Michael Armbrust <michael@databricks.com>    2014-07-23 16:26:55 -0700
commit     1871574a240e6f28adeb6bc8accc98c851cafae5 (patch)
tree       082d8477467200c3e4dae7f78887e668b8bf3b1f /sql
parent     e060d3ee2d910a5a802bb29630dca6f66cc0525d (diff)
[SPARK-2569][SQL] Fix shipping of TEMPORARY hive UDFs.
Instead of shipping just the function name and then looking up its info on the workers, we now ship the whole class name. Also, since the file was getting pretty large, I refactored it to move the type conversion code out into its own file.

Author: Michael Armbrust <michael@databricks.com>

Closes #1552 from marmbrus/fixTempUdfs and squashes the following commits:

b695904 [Michael Armbrust] Make add jar execute with Hive. Ship the whole function class name since sometimes we cannot look up temporary functions on the workers.
Diffstat (limited to 'sql')
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 230
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala         |   4
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala       | 262
3 files changed, 261 insertions, 235 deletions
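The core of the change is visible in hiveUdfs.scala below: rather than serializing only the function name and re-resolving it through Hive's FunctionRegistry on each worker (which fails for TEMPORARY functions registered only on the driver), the driver serializes the resolved class name and the worker instantiates the UDF reflectively. The following is a minimal sketch of that pattern, not the patch itself: `ShippedUdf` and `my_temp_udf` are hypothetical names used purely for illustration, and where the patch uses Spark's internal `getContextOrSparkClassLoader`, the sketch falls back to the thread context classloader directly.

```scala
import org.apache.hadoop.hive.ql.exec.UDF

// Hypothetical wrapper for illustration; the real patch does this inside
// HiveFunctionFactory.createFunction using getContextOrSparkClassLoader.
case class ShippedUdf(functionClassName: String) {

  // Runs on the worker after deserialization: pick a classloader that can see
  // jars added via `ADD JAR`, falling back to this class's own loader.
  @transient private lazy val loader: ClassLoader =
    Option(Thread.currentThread().getContextClassLoader)
      .getOrElse(getClass.getClassLoader)

  // Instantiate the UDF reflectively from its fully-qualified class name,
  // instead of asking Hive's FunctionRegistry, which has no entry for a
  // TEMPORARY function registered only on the driver.
  @transient lazy val function: UDF =
    loader.loadClass(functionClassName).newInstance().asInstanceOf[UDF]
}

// Driver side (sketch): capture the class name while the registry entry is visible.
//   val info = FunctionRegistry.getFunctionInfo("my_temp_udf")  // hypothetical name
//   val udf  = ShippedUdf(info.getFunctionClass.getName)
```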
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
new file mode 100644
index 0000000000..ad7dc0ecdb
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.hive.serde2.objectinspector.primitive._
+import org.apache.hadoop.hive.serde2.{io => hiveIo}
+import org.apache.hadoop.{io => hadoopIo}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types
+import org.apache.spark.sql.catalyst.types._
+
+/* Implicit conversions */
+import scala.collection.JavaConversions._
+
+private[hive] trait HiveInspectors {
+
+ def javaClassToDataType(clz: Class[_]): DataType = clz match {
+ // writable
+ case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
+ case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
+ case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
+ case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
+ case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
+ case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType
+ case c: Class[_] if c == classOf[hadoopIo.Text] => StringType
+ case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType
+ case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType
+ case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType
+ case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType
+ case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType
+
+ // java class
+ case c: Class[_] if c == classOf[java.lang.String] => StringType
+ case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
+ case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
+ case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
+ case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
+ case c: Class[_] if c == classOf[java.lang.Short] => ShortType
+ case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
+ case c: Class[_] if c == classOf[java.lang.Long] => LongType
+ case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
+ case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
+ case c: Class[_] if c == classOf[java.lang.Float] => FloatType
+ case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
+
+ // primitive type
+ case c: Class[_] if c == java.lang.Short.TYPE => ShortType
+ case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
+ case c: Class[_] if c == java.lang.Long.TYPE => LongType
+ case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
+ case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
+ case c: Class[_] if c == java.lang.Float.TYPE => FloatType
+ case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
+
+ case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType))
+ }
+
+ /** Converts hive types to native catalyst types. */
+ def unwrap(a: Any): Any = a match {
+ case null => null
+ case i: hadoopIo.IntWritable => i.get
+ case t: hadoopIo.Text => t.toString
+ case l: hadoopIo.LongWritable => l.get
+ case d: hadoopIo.DoubleWritable => d.get
+ case d: hiveIo.DoubleWritable => d.get
+ case s: hiveIo.ShortWritable => s.get
+ case b: hadoopIo.BooleanWritable => b.get
+ case b: hiveIo.ByteWritable => b.get
+ case b: hadoopIo.FloatWritable => b.get
+ case b: hadoopIo.BytesWritable => {
+ val bytes = new Array[Byte](b.getLength)
+ System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength)
+ bytes
+ }
+ case t: hiveIo.TimestampWritable => t.getTimestamp
+ case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue())
+ case list: java.util.List[_] => list.map(unwrap)
+ case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap
+ case array: Array[_] => array.map(unwrap).toSeq
+ case p: java.lang.Short => p
+ case p: java.lang.Long => p
+ case p: java.lang.Float => p
+ case p: java.lang.Integer => p
+ case p: java.lang.Double => p
+ case p: java.lang.Byte => p
+ case p: java.lang.Boolean => p
+ case str: String => str
+ case p: java.math.BigDecimal => p
+ case p: Array[Byte] => p
+ case p: java.sql.Timestamp => p
+ }
+
+ def unwrapData(data: Any, oi: ObjectInspector): Any = oi match {
+ case hvoi: HiveVarcharObjectInspector =>
+ if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
+ case hdoi: HiveDecimalObjectInspector =>
+ if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+ case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
+ case li: ListObjectInspector =>
+ Option(li.getList(data))
+ .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq)
+ .orNull
+ case mi: MapObjectInspector =>
+ Option(mi.getMap(data)).map(
+ _.map {
+ case (k,v) =>
+ (unwrapData(k, mi.getMapKeyObjectInspector),
+ unwrapData(v, mi.getMapValueObjectInspector))
+ }.toMap).orNull
+ case si: StructObjectInspector =>
+ val allRefs = si.getAllStructFieldRefs
+ new GenericRow(
+ allRefs.map(r =>
+ unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray)
+ }
+
+ /** Converts native catalyst types to the types expected by Hive */
+ def wrap(a: Any): AnyRef = a match {
+ case s: String => new hadoopIo.Text(s) // TODO why should be Text?
+ case i: Int => i: java.lang.Integer
+ case b: Boolean => b: java.lang.Boolean
+ case f: Float => f: java.lang.Float
+ case d: Double => d: java.lang.Double
+ case l: Long => l: java.lang.Long
+ case l: Short => l: java.lang.Short
+ case l: Byte => l: java.lang.Byte
+ case b: BigDecimal => b.bigDecimal
+ case b: Array[Byte] => b
+ case t: java.sql.Timestamp => t
+ case s: Seq[_] => seqAsJavaList(s.map(wrap))
+ case m: Map[_,_] =>
+ mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) })
+ case null => null
+ }
+
+ def toInspector(dataType: DataType): ObjectInspector = dataType match {
+ case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
+ case MapType(keyType, valueType) =>
+ ObjectInspectorFactory.getStandardMapObjectInspector(
+ toInspector(keyType), toInspector(valueType))
+ case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector
+ case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector
+ case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
+ case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector
+ case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector
+ case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector
+ case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector
+ case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
+ case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector
+ case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
+ case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
+ case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
+ case StructType(fields) =>
+ ObjectInspectorFactory.getStandardStructObjectInspector(
+ fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
+ }
+
+ def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
+ case s: StructObjectInspector =>
+ StructType(s.getAllStructFieldRefs.map(f => {
+ types.StructField(
+ f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true)
+ }))
+ case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector))
+ case m: MapObjectInspector =>
+ MapType(
+ inspectorToDataType(m.getMapKeyObjectInspector),
+ inspectorToDataType(m.getMapValueObjectInspector))
+ case _: WritableStringObjectInspector => StringType
+ case _: JavaStringObjectInspector => StringType
+ case _: WritableIntObjectInspector => IntegerType
+ case _: JavaIntObjectInspector => IntegerType
+ case _: WritableDoubleObjectInspector => DoubleType
+ case _: JavaDoubleObjectInspector => DoubleType
+ case _: WritableBooleanObjectInspector => BooleanType
+ case _: JavaBooleanObjectInspector => BooleanType
+ case _: WritableLongObjectInspector => LongType
+ case _: JavaLongObjectInspector => LongType
+ case _: WritableShortObjectInspector => ShortType
+ case _: JavaShortObjectInspector => ShortType
+ case _: WritableByteObjectInspector => ByteType
+ case _: JavaByteObjectInspector => ByteType
+ case _: WritableFloatObjectInspector => FloatType
+ case _: JavaFloatObjectInspector => FloatType
+ case _: WritableBinaryObjectInspector => BinaryType
+ case _: JavaBinaryObjectInspector => BinaryType
+ case _: WritableHiveDecimalObjectInspector => DecimalType
+ case _: JavaHiveDecimalObjectInspector => DecimalType
+ case _: WritableTimestampObjectInspector => TimestampType
+ case _: JavaTimestampObjectInspector => TimestampType
+ }
+
+ implicit class typeInfoConversions(dt: DataType) {
+ import org.apache.hadoop.hive.serde2.typeinfo._
+ import TypeInfoFactory._
+
+ def toTypeInfo: TypeInfo = dt match {
+ case BinaryType => binaryTypeInfo
+ case BooleanType => booleanTypeInfo
+ case ByteType => byteTypeInfo
+ case DoubleType => doubleTypeInfo
+ case FloatType => floatTypeInfo
+ case IntegerType => intTypeInfo
+ case LongType => longTypeInfo
+ case ShortType => shortTypeInfo
+ case StringType => stringTypeInfo
+ case DecimalType => decimalTypeInfo
+ case TimestampType => timestampTypeInfo
+ case NullType => voidTypeInfo
+ }
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 53480a521d..c4ca9f362a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -42,8 +42,6 @@ private[hive] case class ShellCommand(cmd: String) extends Command
private[hive] case class SourceCommand(filePath: String) extends Command
-private[hive] case class AddJar(jarPath: String) extends Command
-
private[hive] case class AddFile(filePath: String) extends Command
/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
@@ -229,7 +227,7 @@ private[hive] object HiveQl {
} else if (sql.trim.toLowerCase.startsWith("uncache table")) {
CacheCommand(sql.trim.drop(14).trim, false)
} else if (sql.trim.toLowerCase.startsWith("add jar")) {
- AddJar(sql.trim.drop(8))
+ NativeCommand(sql)
} else if (sql.trim.toLowerCase.startsWith("add file")) {
AddFile(sql.trim.drop(9))
} else if (sql.trim.toLowerCase.startsWith("dfs")) {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index fc33c5b460..057eb60a02 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -24,22 +24,19 @@ import org.apache.hadoop.hive.ql.exec.UDF
import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
import org.apache.hadoop.hive.ql.udf.generic._
-import org.apache.hadoop.hive.serde2.objectinspector._
-import org.apache.hadoop.hive.serde2.objectinspector.primitive._
-import org.apache.hadoop.hive.serde2.{io => hiveIo}
-import org.apache.hadoop.{io => hadoopIo}
import org.apache.spark.sql.Logging
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.types
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.util.Utils.getContextOrSparkClassLoader
/* Implicit conversions */
import scala.collection.JavaConversions._
-private[hive] object HiveFunctionRegistry
- extends analysis.FunctionRegistry with HiveFunctionFactory with HiveInspectors {
+private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors {
+
+ def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name)
def lookupFunction(name: String, children: Seq[Expression]): Expression = {
// We only look it up to see if it exists, but do not include it in the HiveUDF since it is
@@ -47,111 +44,37 @@ private[hive] object HiveFunctionRegistry
val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse(
sys.error(s"Couldn't find function $name"))
+ val functionClassName = functionInfo.getFunctionClass.getName()
+
if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- val function = createFunction[UDF](name)
+ val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF]
val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType)
HiveSimpleUdf(
- name,
+ functionClassName,
children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) }
)
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdf(name, children)
+ HiveGenericUdf(functionClassName, children)
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdaf(name, children)
+ HiveGenericUdaf(functionClassName, children)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdtf(name, Nil, children)
+ HiveGenericUdtf(functionClassName, Nil, children)
} else {
sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
}
}
-
- def javaClassToDataType(clz: Class[_]): DataType = clz match {
- // writable
- case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
- case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
- case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
- case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
- case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
- case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType
- case c: Class[_] if c == classOf[hadoopIo.Text] => StringType
- case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType
- case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType
- case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType
- case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType
- case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType
-
- // java class
- case c: Class[_] if c == classOf[java.lang.String] => StringType
- case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
- case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
- case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
- case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
- case c: Class[_] if c == classOf[java.lang.Short] => ShortType
- case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
- case c: Class[_] if c == classOf[java.lang.Long] => LongType
- case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
- case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
- case c: Class[_] if c == classOf[java.lang.Float] => FloatType
- case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
-
- // primitive type
- case c: Class[_] if c == java.lang.Short.TYPE => ShortType
- case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
- case c: Class[_] if c == java.lang.Long.TYPE => LongType
- case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
- case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
- case c: Class[_] if c == java.lang.Float.TYPE => FloatType
- case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
-
- case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType))
- }
}
private[hive] trait HiveFunctionFactory {
- def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name)
- def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass
- def createFunction[UDFType](name: String) =
- getFunctionClass(name).newInstance.asInstanceOf[UDFType]
-
- /** Converts hive types to native catalyst types. */
- def unwrap(a: Any): Any = a match {
- case null => null
- case i: hadoopIo.IntWritable => i.get
- case t: hadoopIo.Text => t.toString
- case l: hadoopIo.LongWritable => l.get
- case d: hadoopIo.DoubleWritable => d.get
- case d: hiveIo.DoubleWritable => d.get
- case s: hiveIo.ShortWritable => s.get
- case b: hadoopIo.BooleanWritable => b.get
- case b: hiveIo.ByteWritable => b.get
- case b: hadoopIo.FloatWritable => b.get
- case b: hadoopIo.BytesWritable => {
- val bytes = new Array[Byte](b.getLength)
- System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength)
- bytes
- }
- case t: hiveIo.TimestampWritable => t.getTimestamp
- case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue())
- case list: java.util.List[_] => list.map(unwrap)
- case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap
- case array: Array[_] => array.map(unwrap).toSeq
- case p: java.lang.Short => p
- case p: java.lang.Long => p
- case p: java.lang.Float => p
- case p: java.lang.Integer => p
- case p: java.lang.Double => p
- case p: java.lang.Byte => p
- case p: java.lang.Boolean => p
- case str: String => str
- case p: java.math.BigDecimal => p
- case p: Array[Byte] => p
- case p: java.sql.Timestamp => p
- }
+ val functionClassName: String
+
+ def createFunction[UDFType]() =
+ getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
}
private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory {
@@ -160,19 +83,17 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu
type UDFType
type EvaluatedType = Any
- val name: String
-
def nullable = true
def references = children.flatMap(_.references).toSet
- // FunctionInfo is not serializable so we must look it up here again.
- lazy val functionInfo = getFunctionInfo(name)
- lazy val function = createFunction[UDFType](name)
+ lazy val function = createFunction[UDFType]()
- override def toString = s"$nodeName#${functionInfo.getDisplayName}(${children.mkString(",")})"
+ override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
}
-private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf {
+private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression])
+ extends HiveUdf {
+
import org.apache.spark.sql.hive.HiveFunctionRegistry._
type UDFType = UDF
@@ -226,7 +147,7 @@ private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression])
}
}
-private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression])
+private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression])
extends HiveUdf with HiveInspectors {
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
@@ -277,131 +198,8 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression])
}
}
-private[hive] trait HiveInspectors {
-
- def unwrapData(data: Any, oi: ObjectInspector): Any = oi match {
- case hvoi: HiveVarcharObjectInspector =>
- if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
- case hdoi: HiveDecimalObjectInspector =>
- if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
- case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
- case li: ListObjectInspector =>
- Option(li.getList(data))
- .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq)
- .orNull
- case mi: MapObjectInspector =>
- Option(mi.getMap(data)).map(
- _.map {
- case (k,v) =>
- (unwrapData(k, mi.getMapKeyObjectInspector),
- unwrapData(v, mi.getMapValueObjectInspector))
- }.toMap).orNull
- case si: StructObjectInspector =>
- val allRefs = si.getAllStructFieldRefs
- new GenericRow(
- allRefs.map(r =>
- unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray)
- }
-
- /** Converts native catalyst types to the types expected by Hive */
- def wrap(a: Any): AnyRef = a match {
- case s: String => new hadoopIo.Text(s) // TODO why should be Text?
- case i: Int => i: java.lang.Integer
- case b: Boolean => b: java.lang.Boolean
- case f: Float => f: java.lang.Float
- case d: Double => d: java.lang.Double
- case l: Long => l: java.lang.Long
- case l: Short => l: java.lang.Short
- case l: Byte => l: java.lang.Byte
- case b: BigDecimal => b.bigDecimal
- case b: Array[Byte] => b
- case t: java.sql.Timestamp => t
- case s: Seq[_] => seqAsJavaList(s.map(wrap))
- case m: Map[_,_] =>
- mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) })
- case null => null
- }
-
- def toInspector(dataType: DataType): ObjectInspector = dataType match {
- case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
- case MapType(keyType, valueType) =>
- ObjectInspectorFactory.getStandardMapObjectInspector(
- toInspector(keyType), toInspector(valueType))
- case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector
- case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector
- case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
- case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector
- case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector
- case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector
- case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector
- case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
- case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector
- case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
- case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
- case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
- case StructType(fields) =>
- ObjectInspectorFactory.getStandardStructObjectInspector(
- fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
- }
-
- def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
- case s: StructObjectInspector =>
- StructType(s.getAllStructFieldRefs.map(f => {
- types.StructField(
- f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true)
- }))
- case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector))
- case m: MapObjectInspector =>
- MapType(
- inspectorToDataType(m.getMapKeyObjectInspector),
- inspectorToDataType(m.getMapValueObjectInspector))
- case _: WritableStringObjectInspector => StringType
- case _: JavaStringObjectInspector => StringType
- case _: WritableIntObjectInspector => IntegerType
- case _: JavaIntObjectInspector => IntegerType
- case _: WritableDoubleObjectInspector => DoubleType
- case _: JavaDoubleObjectInspector => DoubleType
- case _: WritableBooleanObjectInspector => BooleanType
- case _: JavaBooleanObjectInspector => BooleanType
- case _: WritableLongObjectInspector => LongType
- case _: JavaLongObjectInspector => LongType
- case _: WritableShortObjectInspector => ShortType
- case _: JavaShortObjectInspector => ShortType
- case _: WritableByteObjectInspector => ByteType
- case _: JavaByteObjectInspector => ByteType
- case _: WritableFloatObjectInspector => FloatType
- case _: JavaFloatObjectInspector => FloatType
- case _: WritableBinaryObjectInspector => BinaryType
- case _: JavaBinaryObjectInspector => BinaryType
- case _: WritableHiveDecimalObjectInspector => DecimalType
- case _: JavaHiveDecimalObjectInspector => DecimalType
- case _: WritableTimestampObjectInspector => TimestampType
- case _: JavaTimestampObjectInspector => TimestampType
- }
-
- implicit class typeInfoConversions(dt: DataType) {
- import org.apache.hadoop.hive.serde2.typeinfo._
- import TypeInfoFactory._
-
- def toTypeInfo: TypeInfo = dt match {
- case BinaryType => binaryTypeInfo
- case BooleanType => booleanTypeInfo
- case ByteType => byteTypeInfo
- case DoubleType => doubleTypeInfo
- case FloatType => floatTypeInfo
- case IntegerType => intTypeInfo
- case LongType => longTypeInfo
- case ShortType => shortTypeInfo
- case StringType => stringTypeInfo
- case DecimalType => decimalTypeInfo
- case TimestampType => timestampTypeInfo
- case NullType => voidTypeInfo
- }
- }
-}
-
private[hive] case class HiveGenericUdaf(
- name: String,
+ functionClassName: String,
children: Seq[Expression]) extends AggregateExpression
with HiveInspectors
with HiveFunctionFactory {
@@ -409,7 +207,7 @@ private[hive] case class HiveGenericUdaf(
type UDFType = AbstractGenericUDAFResolver
@transient
- protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name)
+ protected lazy val resolver: AbstractGenericUDAFResolver = createFunction()
@transient
protected lazy val objectInspector = {
@@ -426,9 +224,9 @@ private[hive] case class HiveGenericUdaf(
def references: Set[Attribute] = children.map(_.references).flatten.toSet
- override def toString = s"$nodeName#$name(${children.mkString(",")})"
+ override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
- def newInstance() = new HiveUdafFunction(name, children, this)
+ def newInstance() = new HiveUdafFunction(functionClassName, children, this)
}
/**
@@ -443,7 +241,7 @@ private[hive] case class HiveGenericUdaf(
* user defined aggregations, which have clean semantics even in a partitioned execution.
*/
private[hive] case class HiveGenericUdtf(
- name: String,
+ functionClassName: String,
aliasNames: Seq[String],
children: Seq[Expression])
extends Generator with HiveInspectors with HiveFunctionFactory {
@@ -451,7 +249,7 @@ private[hive] case class HiveGenericUdtf(
override def references = children.flatMap(_.references).toSet
@transient
- protected lazy val function: GenericUDTF = createFunction(name)
+ protected lazy val function: GenericUDTF = createFunction()
protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
@@ -506,11 +304,11 @@ private[hive] case class HiveGenericUdtf(
}
}
- override def toString = s"$nodeName#$name(${children.mkString(",")})"
+ override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
}
private[hive] case class HiveUdafFunction(
- functionName: String,
+ functionClassName: String,
exprs: Seq[Expression],
base: AggregateExpression)
extends AggregateFunction
@@ -519,7 +317,7 @@ private[hive] case class HiveUdafFunction(
def this() = this(null, null, null)
- private val resolver = createFunction[AbstractGenericUDAFResolver](functionName)
+ private val resolver = createFunction[AbstractGenericUDAFResolver]()
private val inspectors = exprs.map(_.dataType).map(toInspector).toArray