-rw-r--r--  sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala | 45
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala                                      |  2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala                                         | 14
3 files changed, 56 insertions(+), 5 deletions(-)
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 5903b9e71c..eb1895f263 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -463,6 +463,51 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
assert(conf.get("spark.sql.hive.version") === Some("1.2.1"))
}
}
+
+ test("SPARK-11595 ADD JAR with input path having URL scheme") {
+ withJdbcStatement { statement =>
+ val jarPath = "../hive/src/test/resources/TestUDTF.jar"
+ val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath"
+
+ Seq(
+ s"ADD JAR $jarURL",
+ s"""CREATE TEMPORARY FUNCTION udtf_count2
+ |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
+ """.stripMargin
+ ).foreach(statement.execute)
+
+ val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2")
+
+ assert(rs1.next())
+ assert(rs1.getString(1) === "Function: udtf_count2")
+
+ assert(rs1.next())
+ assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") {
+ rs1.getString(1)
+ }
+
+ assert(rs1.next())
+ assert(rs1.getString(1) === "Usage: To be added.")
+
+ val dataPath = "../hive/src/test/resources/data/files/kv1.txt"
+
+ Seq(
+ s"CREATE TABLE test_udtf(key INT, value STRING)",
+ s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf"
+ ).foreach(statement.execute)
+
+ val rs2 = statement.executeQuery(
+ "SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS cc")
+
+ assert(rs2.next())
+ assert(rs2.getInt(1) === 97)
+ assert(rs2.getInt(2) === 500)
+
+ assert(rs2.next())
+ assert(rs2.getInt(1) === 97)
+ assert(rs2.getInt(2) === 500)
+ }
+ }
}

class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
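The test above drives the fix end to end over JDBC: a jar referenced by a file:// URL is added to the session, a UDTF from that jar is registered, and both DESCRIBE FUNCTION and a LATERAL VIEW query are verified. Outside the test harness, the same path can be exercised with a plain Hive JDBC client; a minimal sketch, where the endpoint, credentials, and jar location are illustrative assumptions rather than values from this patch:

import java.sql.DriverManager

object AddJarViaUrlSketch {
  def main(args: Array[String]): Unit = {
    // Assumed Thrift server endpoint; adjust host/port/credentials for your deployment.
    val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000/default", "user", "")
    val stmt = conn.createStatement()
    try {
      // The behavior under test: ADD JAR must accept a path carrying a URL scheme.
      stmt.execute("ADD JAR file:///tmp/TestUDTF.jar")
      stmt.execute(
        "CREATE TEMPORARY FUNCTION udtf_count2 AS " +
          "'org.apache.spark.sql.hive.execution.GenericUDTFCount2'")
      val rs = stmt.executeQuery("DESCRIBE FUNCTION udtf_count2")
      while (rs.next()) println(rs.getString(1))
    } finally {
      stmt.close()
      conn.close()
    }
  }
}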
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index ba6204633b..0c473799cc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -454,7 +454,7 @@ class HiveContext private[hive](
// Note that HiveUDFs will be overridden by functions registered in this context.
@transient
override protected[sql] lazy val functionRegistry: FunctionRegistry =
- new HiveFunctionRegistry(FunctionRegistry.builtin.copy()) {
+ new HiveFunctionRegistry(FunctionRegistry.builtin.copy(), this) {
override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
// Hive Registry need current database to lookup function
// TODO: the current database of executionHive should be consistent with metadataHive
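This one-line change threads the HiveContext into the registry so that the Hive-side lookups changed below can run under the execution Hive client's session state. The registry itself is a fallback decorator: Spark's builtin registry is consulted first, and the Hive lookup only runs on a miss. A minimal self-contained sketch of that pattern, using simplified stand-in types rather than Spark's real FunctionRegistry interface:

import scala.util.Try

// Stand-in for a function registry that may throw on unknown names.
trait Registry {
  def lookup(name: String): String
}

// Decorator: consult the underlying (builtin) registry first, and only
// fall back to the secondary lookup when that consultation fails.
class FallbackRegistry(underlying: Registry)(fallback: String => String) extends Registry {
  def lookup(name: String): String =
    Try(underlying.lookup(name)).getOrElse(fallback(name))
}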
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index a9db70119d..e6fe2ad5f2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -46,17 +46,23 @@ import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.types._

-private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
+private[hive] class HiveFunctionRegistry(
+    underlying: analysis.FunctionRegistry,
+    hiveContext: HiveContext)
  extends analysis.FunctionRegistry with HiveInspectors {

-  def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name)
+  def getFunctionInfo(name: String): FunctionInfo = {
+    hiveContext.executionHive.withHiveState {
+      FunctionRegistry.getFunctionInfo(name)
+    }
+  }

  override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
Try(underlying.lookupFunction(name, children)).getOrElse {
// We only look it up to see if it exists, but do not include it in the HiveUDF since it is
// not always serializable.
val functionInfo: FunctionInfo =
- Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
+ Option(getFunctionInfo(name.toLowerCase)).getOrElse(
          throw new AnalysisException(s"undefined function $name"))

      val functionClassName = functionInfo.getFunctionClass.getName
@@ -110,7 +116,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
override def lookupFunction(name: String): Option[ExpressionInfo] = {
underlying.lookupFunction(name).orElse(
Try {
- val info = FunctionRegistry.getFunctionInfo(name)
+ val info = getFunctionInfo(name)
val annotation = info.getFunctionClass.getAnnotation(classOf[Description])
if (annotation != null) {
Some(new ExpressionInfo(
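Both lookup sites now funnel through the new getFunctionInfo, and its withHiveState wrapper is the substance of the fix: jars added via ADD JAR and the temporary functions defined from them are registered in the execution Hive client's thread-local session state, so the lookup must run under that same state (and its classloader) to see them. A generic sketch of the wrap-a-block-in-session-state pattern, with illustrative names rather than the real withHiveState implementation:

// Illustrative: run `body` with a session-specific classloader installed,
// restoring the caller's classloader afterwards. The real withHiveState
// additionally swaps in Hive's thread-local SessionState.
def withSessionClassLoader[A](sessionLoader: ClassLoader)(body: => A): A = {
  val saved = Thread.currentThread().getContextClassLoader
  Thread.currentThread().setContextClassLoader(sessionLoader)
  try body
  finally Thread.currentThread().setContextClassLoader(saved)
}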