diff options
author | Davies Liu <davies@databricks.com> | 2015-11-09 23:27:36 -0800 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-11-09 23:27:36 -0800 |
commit | d6cd3a18e720e8f6f1f307e0dffad3512952d997 (patch) | |
tree | ff96731005cb6019eb447134a8fce6892409e741 /sql | |
parent | c4e19b3819df4cd7a1c495a00bd2844cf55f4dbd (diff) | |
download | spark-d6cd3a18e720e8f6f1f307e0dffad3512952d997.tar.gz spark-d6cd3a18e720e8f6f1f307e0dffad3512952d997.tar.bz2 spark-d6cd3a18e720e8f6f1f307e0dffad3512952d997.zip |
[SPARK-11599] [SQL] fix NPE when resolve Hive UDF in SQLParser
The DataFrame APIs that takes a SQL expression always use SQLParser, then the HiveFunctionRegistry will called outside of Hive state, cause NPE if there is not a active Session State for current thread (in PySpark).
cc rxin yhuai
Author: Davies Liu <davies@databricks.com>
Closes #9576 from davies/hive_udf.
Diffstat (limited to 'sql')
-rw-r--r-- | sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 10 | ||||
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 33 |
2 files changed, 34 insertions, 9 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2d72b959af..c5f69657f5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -454,7 +454,15 @@ class HiveContext private[hive]( // Note that HiveUDFs will be overridden by functions registered in this context. @transient override protected[sql] lazy val functionRegistry: FunctionRegistry = - new HiveFunctionRegistry(FunctionRegistry.builtin.copy()) + new HiveFunctionRegistry(FunctionRegistry.builtin.copy()) { + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + // Hive Registry need current database to lookup function + // TODO: the current database of executionHive should be consistent with metadataHive + executionHive.withHiveState { + super.lookupFunction(name, children) + } + } + } // The Hive UDF current_database() is foldable, will be evaluated by optimizer, but the optimizer // can't access the SessionState of metadataHive. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 78378c8b69..f0a7a6cc7a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -20,22 +20,19 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} -import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin - import scala.util.Try -import org.scalatest.BeforeAndAfter - import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkFiles, SparkException} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.test.TestHiveContext -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.{SparkException, SparkFiles} case class TestData(a: Int, b: String) @@ -1237,6 +1234,26 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } + test("lookup hive UDF in another thread") { + val e = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + } + assert(e.getMessage.contains("undefined function not_a_udf")) + var success = false + val t = new Thread("test") { + override def run(): Unit = { + val e = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + } + assert(e.getMessage.contains("undefined function not_a_udf")) + success = true + } + } + t.start() + t.join() + assert(success) + } + createQueryTest("select from thrift based table", "SELECT * from src_thrift") |