 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala                 |   5
 sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala               |  93
 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala |   7
 sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala         |  11
 sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala         | 107
 5 files changed, 173 insertions(+), 50 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index f4c42bbc5b..cd4e5a239e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1128,7 +1128,10 @@ private[hive] object HiveQl {
Explode(attributes, nodeToExpr(child))
case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
- HiveGenericUdtf(functionName, attributes, children.map(nodeToExpr))
+ HiveGenericUdtf(
+ new HiveFunctionWrapper(functionName),
+ attributes,
+ children.map(nodeToExpr))
case a: ASTNode =>
throw new NotImplementedError(
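
The hunk above replaces the bare function-name String with a HiveFunctionWrapper at the parse site, so the resulting UDTF expression carries something that is safe to ship to executors. A minimal sketch of the wrapper's contract, assuming spark-hive and hive-exec on the classpath (UDFToInteger is a stock Hive UDF, used here only as an example):

    // The wrapper holds a class name and materializes the function on demand,
    // so the expression itself never has to serialize a live UDF object.
    val wrapper = new HiveFunctionWrapper("org.apache.hadoop.hive.ql.udf.UDFToInteger")
    val udf = wrapper.createFunction[org.apache.hadoop.hive.ql.exec.UDF]()
    println(udf.getClass.getName) // org.apache.hadoop.hive.ql.udf.UDFToInteger
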
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index fecf8faaf4..ed2e96df8a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -54,46 +54,31 @@ private[hive] abstract class HiveFunctionRegistry
val functionClassName = functionInfo.getFunctionClass.getName
if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveSimpleUdf(functionClassName, children)
+ HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdf(functionClassName, children)
+ HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children)
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdaf(functionClassName, children)
+ HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveUdaf(functionClassName, children)
+ HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdtf(functionClassName, Nil, children)
+ HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children)
} else {
sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
}
}
}
-private[hive] trait HiveFunctionFactory {
- val functionClassName: String
-
- def createFunction[UDFType]() =
- getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
-}
-
-private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory {
- self: Product =>
-
- type UDFType
+private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+ extends Expression with HiveInspectors with Logging {
type EvaluatedType = Any
+ type UDFType = UDF
def nullable = true
- lazy val function = createFunction[UDFType]()
-
- override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
-}
-
-private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression])
- extends HiveUdf with HiveInspectors {
-
- type UDFType = UDF
+ @transient
+ lazy val function = funcWrapper.createFunction[UDFType]()
@transient
protected lazy val method =
@@ -131,6 +116,8 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[
.convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*),
returnInspector)
}
+
+ override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
// Adapter from Catalyst ExpressionResult to Hive DeferredObject
@@ -144,16 +131,23 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector)
override def get(): AnyRef = wrap(func(), oi)
}
-private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression])
- extends HiveUdf with HiveInspectors {
+private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+ extends Expression with HiveInspectors with Logging {
type UDFType = GenericUDF
+ type EvaluatedType = Any
+
+ def nullable = true
+
+ @transient
+ lazy val function = funcWrapper.createFunction[UDFType]()
@transient
protected lazy val argumentInspectors = children.map(toInspector)
@transient
- protected lazy val returnInspector =
+ protected lazy val returnInspector = {
function.initializeAndFoldConstants(argumentInspectors.toArray)
+ }
@transient
protected lazy val isUDFDeterministic = {
@@ -183,18 +177,19 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq
}
unwrap(function.evaluate(deferedObjects), returnInspector)
}
+
+ override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
private[hive] case class HiveGenericUdaf(
- functionClassName: String,
+ funcWrapper: HiveFunctionWrapper,
children: Seq[Expression]) extends AggregateExpression
- with HiveInspectors
- with HiveFunctionFactory {
+ with HiveInspectors {
type UDFType = AbstractGenericUDAFResolver
@transient
- protected lazy val resolver: AbstractGenericUDAFResolver = createFunction()
+ protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction()
@transient
protected lazy val objectInspector = {
@@ -209,22 +204,22 @@ private[hive] case class HiveGenericUdaf(
def nullable: Boolean = true
- override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
+ override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
- def newInstance() = new HiveUdafFunction(functionClassName, children, this)
+ def newInstance() = new HiveUdafFunction(funcWrapper, children, this)
}
/** It is used as a wrapper for the hive functions which uses UDAF interface */
private[hive] case class HiveUdaf(
- functionClassName: String,
+ funcWrapper: HiveFunctionWrapper,
children: Seq[Expression]) extends AggregateExpression
- with HiveInspectors
- with HiveFunctionFactory {
+ with HiveInspectors {
type UDFType = UDAF
@transient
- protected lazy val resolver: AbstractGenericUDAFResolver = new GenericUDAFBridge(createFunction())
+ protected lazy val resolver: AbstractGenericUDAFResolver =
+ new GenericUDAFBridge(funcWrapper.createFunction())
@transient
protected lazy val objectInspector = {
@@ -239,10 +234,10 @@ private[hive] case class HiveUdaf(
def nullable: Boolean = true
- override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
+ override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
def newInstance() =
- new HiveUdafFunction(functionClassName, children, this, true)
+ new HiveUdafFunction(funcWrapper, children, this, true)
}
/**
@@ -257,13 +252,13 @@ private[hive] case class HiveUdaf(
* user defined aggregations, which have clean semantics even in a partitioned execution.
*/
private[hive] case class HiveGenericUdtf(
- functionClassName: String,
+ funcWrapper: HiveFunctionWrapper,
aliasNames: Seq[String],
children: Seq[Expression])
- extends Generator with HiveInspectors with HiveFunctionFactory {
+ extends Generator with HiveInspectors {
@transient
- protected lazy val function: GenericUDTF = createFunction()
+ protected lazy val function: GenericUDTF = funcWrapper.createFunction()
@transient
protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
@@ -320,25 +315,24 @@ private[hive] case class HiveGenericUdtf(
}
}
- override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
+ override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
private[hive] case class HiveUdafFunction(
- functionClassName: String,
+ funcWrapper: HiveFunctionWrapper,
exprs: Seq[Expression],
base: AggregateExpression,
isUDAFBridgeRequired: Boolean = false)
extends AggregateFunction
- with HiveInspectors
- with HiveFunctionFactory {
+ with HiveInspectors {
def this() = this(null, null, null)
private val resolver =
if (isUDAFBridgeRequired) {
- new GenericUDAFBridge(createFunction[UDAF]())
+ new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
} else {
- createFunction[AbstractGenericUDAFResolver]()
+ funcWrapper.createFunction[AbstractGenericUDAFResolver]()
}
private val inspectors = exprs.map(_.dataType).map(toInspector).toArray
@@ -361,3 +355,4 @@ private[hive] case class HiveUdafFunction(
function.iterate(buffer, inputs)
}
}
+
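
Taken together, the hiveUdfs.scala changes apply one pattern to every Hive UDF expression: drop the HiveFunctionFactory mixin, keep only a small serializable wrapper, and mark the heavyweight function instance @transient lazy so it is rebuilt once per JVM after deserialization. A self-contained sketch of that pattern using hypothetical stand-in names (FunctionWrapper and WrappedUdf are not the Spark classes):

    // A serializable carrier for the class name only.
    class FunctionWrapper(var functionClassName: String) extends java.io.Serializable {
      def this() = this(null) // no-arg constructor for serialization frameworks
      def createFunction[T <: AnyRef](): T =
        Class.forName(functionClassName).newInstance().asInstanceOf[T]
    }

    case class WrappedUdf(funcWrapper: FunctionWrapper) {
      // Excluded from serialization; re-created lazily after the expression
      // is shipped to an executor.
      @transient lazy val function: AnyRef = funcWrapper.createFunction[AnyRef]()
      override def toString = s"WrappedUdf#${funcWrapper.functionClassName}"
    }

    val udf = WrappedUdf(new FunctionWrapper("java.util.ArrayList"))
    println(udf.function.getClass.getName) // java.util.ArrayList
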
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index 872f28d514..5fcaf671a8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -60,6 +60,13 @@ class HiveUdfSuite extends QueryTest {
| getStruct(1).f5 FROM src LIMIT 1
""".stripMargin).first() === Row(1, 2, 3, 4, 5))
}
+
+ test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") {
+ checkAnswer(
+ sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"),
+ 8
+ )
+ }
test("hive struct udf") {
sql(
diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
index 76f09cbcde..754ffc4220 100644
--- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
+++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
@@ -43,6 +43,17 @@ import scala.language.implicitConversions
import org.apache.spark.sql.catalyst.types.DecimalType
+class HiveFunctionWrapper(var functionClassName: String) extends java.io.Serializable {
+ // for Serialization
+ def this() = this(null)
+
+ import org.apache.spark.util.Utils._
+ def createFunction[UDFType <: AnyRef](): UDFType = {
+ getContextOrSparkClassLoader
+ .loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
+ }
+}
+
/**
* A compatibility layer for interacting with Hive version 0.12.0.
*/
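
For Hive 0.12.0 the wrapper only needs plain Java serialization, because nothing but the class name crosses the wire; the receiving side always builds a fresh instance. A hedged round-trip sketch of the wrapper above (java.util.Random stands in for a UDF class so the example runs without Hive on the classpath):

    import java.io._

    val wrapper = new HiveFunctionWrapper("java.util.Random")

    // Serialize: only the class name travels.
    val bytes = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bytes)
    oos.writeObject(wrapper)
    oos.close()

    // Deserialize and rebuild the function from the class name alone.
    val restored = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
      .readObject().asInstanceOf[HiveFunctionWrapper]
    val fn = restored.createFunction[java.util.Random]()
    println(fn.nextInt(10))
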
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index 91f7ceac21..7c8cbf10c1 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive
import java.util.{ArrayList => JArrayList}
import java.util.Properties
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.InputFormat
@@ -42,6 +43,112 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
import scala.collection.JavaConversions._
import scala.language.implicitConversions
+
+/**
+ * This class handles UDF creation, and also serialization and deserialization
+ * of UDF instances across process boundaries.
+ *
+ * A detailed discussion can be found at https://github.com/apache/spark/pull/3640
+ *
+ * @param functionClassName UDF class name
+ */
+class HiveFunctionWrapper(var functionClassName: String) extends java.io.Externalizable {
+ // for Serialization
+ def this() = this(null)
+
+ import java.io.{OutputStream, InputStream}
+ import com.esotericsoftware.kryo.Kryo
+ import org.apache.spark.util.Utils._
+ import org.apache.hadoop.hive.ql.exec.Utilities
+ import org.apache.hadoop.hive.ql.exec.UDF
+
+ @transient
+ private val methodDeSerialize = {
+ val method = classOf[Utilities].getDeclaredMethod(
+ "deserializeObjectByKryo",
+ classOf[Kryo],
+ classOf[InputStream],
+ classOf[Class[_]])
+ method.setAccessible(true)
+
+ method
+ }
+
+ @transient
+ private val methodSerialize = {
+ val method = classOf[Utilities].getDeclaredMethod(
+ "serializeObjectByKryo",
+ classOf[Kryo],
+ classOf[Object],
+ classOf[OutputStream])
+ method.setAccessible(true)
+
+ method
+ }
+
+ def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = {
+ methodDeSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), is, clazz)
+ .asInstanceOf[UDFType]
+ }
+
+ def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = {
+ methodSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), function, out)
+ }
+
+ private var instance: AnyRef = null
+
+ def writeExternal(out: java.io.ObjectOutput) {
+ // output the function name
+ out.writeUTF(functionClassName)
+
+ // Write a flag indicating whether the instance is null
+ out.writeBoolean(instance != null)
+ if (instance != null) {
+ // Some UDFs are serializable while others are not;
+ // Hive's Utilities can handle both cases
+ val baos = new java.io.ByteArrayOutputStream()
+ serializePlan(instance, baos)
+ val functionInBytes = baos.toByteArray
+
+ // output the function bytes
+ out.writeInt(functionInBytes.length)
+ out.write(functionInBytes, 0, functionInBytes.length)
+ }
+ }
+
+ def readExternal(in: java.io.ObjectInput) {
+ // read the function name
+ functionClassName = in.readUTF()
+
+ if (in.readBoolean()) {
+ // the instance was non-null on the sending side,
+ // so read the serialized function bytes
+ val functionInBytesLength = in.readInt()
+ val functionInBytes = new Array[Byte](functionInBytesLength)
+ in.readFully(functionInBytes, 0, functionInBytesLength)
+
+ // deserialize the function object via Hive Utilities
+ instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes),
+ getContextOrSparkClassLoader.loadClass(functionClassName))
+ }
+ }
+
+ def createFunction[UDFType <: AnyRef](): UDFType = {
+ if (instance != null) {
+ instance.asInstanceOf[UDFType]
+ } else {
+ val func = getContextOrSparkClassLoader
+ .loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
+ if (!func.isInstanceOf[UDF]) {
+ // Cache the function unless it is a simple UDF:
+ // a fresh instance must be created for each use of a simple UDF
+ instance = func
+ }
+ func
+ }
+ }
+}
+
/**
* A compatibility layer for interacting with Hive version 0.13.1.
*/
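
Under Hive 0.13.1 the wrapper is Externalizable: once a non-simple UDF has been instantiated it is cached and shipped as Kryo bytes produced by Hive's own Utilities, so any state it holds survives the hop to the executor. A hedged round-trip sketch, assuming hive-exec 0.13.1 on the classpath (GenericUDFConcat is a stock GenericUDF, used here only as an example):

    import java.io._
    import org.apache.hadoop.hive.ql.udf.generic.GenericUDFConcat

    val wrapper = new HiveFunctionWrapper(classOf[GenericUDFConcat].getName)
    wrapper.createFunction[GenericUDFConcat]() // not a simple UDF, so it is cached

    // writeExternal emits the class name plus the Kryo-serialized instance.
    val bytes = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bytes)
    oos.writeObject(wrapper)
    oos.close()

    val restored = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
      .readObject().asInstanceOf[HiveFunctionWrapper]

    // createFunction now returns the deserialized instance rather than a new one.
    val udf = restored.createFunction[GenericUDFConcat]()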