-rw-r--r--  project/MimaExcludes.scala | 22
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 28
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 164
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala | 14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala | 59
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala | 21
-rw-r--r--  sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala | 76
-rw-r--r--  sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala | 9
-rw-r--r--  sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala | 5
-rw-r--r--  sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala | 8
-rw-r--r--  sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala | 76
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 155
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 28
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala | 9
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala | 85
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala | 107
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala | 27
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 27
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala | 13
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala | 6
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 32
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 9
22 files changed, 540 insertions(+), 440 deletions(-)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 2d4d146f51..08e4a449cf 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -79,7 +79,27 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.ml.regression.LeastSquaresAggregator.add"),
ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.ml.regression.LeastSquaresCostFun.this")
+ "org.apache.spark.ml.regression.LeastSquaresCostFun.this"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.sql.SQLContext.clearLastInstantiatedContext"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.sql.SQLContext.setLastInstantiatedContext"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.SQLContext$SQLSession"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.sql.SQLContext.detachSession"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.sql.SQLContext.tlSession"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.sql.SQLContext.defaultSession"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.sql.SQLContext.currentSession"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.sql.SQLContext.openSession"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.sql.SQLContext.setSession"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.sql.SQLContext.createSession")
)
case v if v.startsWith("1.5") =>
Seq(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e6122d92b7..ba77b70a37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -51,23 +51,37 @@ class SimpleFunctionRegistry extends FunctionRegistry {
private val functionBuilders =
StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false)
- override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder)
- : Unit = {
+ override def registerFunction(
+ name: String,
+ info: ExpressionInfo,
+ builder: FunctionBuilder): Unit = synchronized {
functionBuilders.put(name, (info, builder))
}
override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
- val func = functionBuilders.get(name).map(_._2).getOrElse {
- throw new AnalysisException(s"undefined function $name")
+ val func = synchronized {
+ functionBuilders.get(name).map(_._2).getOrElse {
+ throw new AnalysisException(s"undefined function $name")
+ }
}
func(children)
}
- override def listFunction(): Seq[String] = functionBuilders.iterator.map(_._1).toList.sorted
+ override def listFunction(): Seq[String] = synchronized {
+ functionBuilders.iterator.map(_._1).toList.sorted
+ }
- override def lookupFunction(name: String): Option[ExpressionInfo] = {
+ override def lookupFunction(name: String): Option[ExpressionInfo] = synchronized {
functionBuilders.get(name).map(_._1)
}
+
+ def copy(): SimpleFunctionRegistry = synchronized {
+ val registry = new SimpleFunctionRegistry
+ functionBuilders.iterator.foreach { case (name, (info, builder)) =>
+ registry.registerFunction(name, info, builder)
+ }
+ registry
+ }
}
/**
@@ -257,7 +271,7 @@ object FunctionRegistry {
expression[InputFileName]("input_file_name")
)
- val builtin: FunctionRegistry = {
+ val builtin: SimpleFunctionRegistry = {
val fr = new SimpleFunctionRegistry
expressions.foreach { case (name, (info, builder)) => fr.registerFunction(name, info, builder) }
fr
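
The copy()/synchronized changes above let SimpleFunctionRegistry serve as a shared template: each new session starts from a snapshot of the builtin functions and can register its own UDFs without affecting other sessions. A minimal sketch of that behavior (illustrative only, not part of the patch; it assumes the catalyst ExpressionInfo and Literal APIs of this Spark version, and "always_one" is a made-up function name):

    import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
    import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal}

    // Each session gets its own copy of the builtin registry.
    val reg1 = FunctionRegistry.builtin.copy()
    val reg2 = FunctionRegistry.builtin.copy()

    // Registering a function in one copy does not leak into the other.
    reg1.registerFunction(
      "always_one",
      new ExpressionInfo(classOf[Literal].getCanonicalName, "always_one"),
      (_: Seq[Expression]) => Literal(1))

    assert(reg1.lookupFunction("always_one").isDefined)
    assert(reg2.lookupFunction("always_one").isEmpty)
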
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index cb0a3e361c..2bdfd82af0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -30,6 +30,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.errors.DialectException
@@ -38,15 +39,12 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
-import org.apache.spark.sql.execution.{Filter, _}
-import org.apache.spark.sql.{execution => sparkexecution}
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.sources._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
+import org.apache.spark.sql.{execution => sparkexecution}
import org.apache.spark.util.Utils
/**
@@ -64,18 +62,30 @@ import org.apache.spark.util.Utils
*
* @since 1.0.0
*/
-class SQLContext(@transient val sparkContext: SparkContext)
- extends org.apache.spark.Logging
- with Serializable {
+class SQLContext private[sql](
+ @transient val sparkContext: SparkContext,
+ @transient protected[sql] val cacheManager: CacheManager)
+ extends org.apache.spark.Logging with Serializable {
self =>
+ def this(sparkContext: SparkContext) = this(sparkContext, new CacheManager)
def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)
/**
+ * Returns a SQLContext as a new session, with separate SQL configurations, temporary tables,
+ * and registered functions, but sharing the same SparkContext and CacheManager.
+ *
+ * @since 1.6.0
+ */
+ def newSession(): SQLContext = {
+ new SQLContext(sparkContext, cacheManager)
+ }
+
+ /**
* @return Spark SQL configuration
*/
- protected[sql] def conf = currentSession().conf
+ protected[sql] lazy val conf = new SQLConf
// `listener` should be only used in the driver
@transient private[sql] val listener = new SQLListener(this)
@@ -142,13 +152,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
- // TODO how to handle the temp table per user session?
@transient
protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf)
- // TODO how to handle the temp function per user session?
@transient
- protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin
+ protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()
@transient
protected[sql] lazy val analyzer: Analyzer =
@@ -198,20 +206,19 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] def executePlan(plan: LogicalPlan) =
new sparkexecution.QueryExecution(this, plan)
- @transient
- protected[sql] val tlSession = new ThreadLocal[SQLSession]() {
- override def initialValue: SQLSession = defaultSession
- }
-
- @transient
- protected[sql] val defaultSession = createSession()
-
protected[sql] def dialectClassName = if (conf.dialect == "sql") {
classOf[DefaultParserDialect].getCanonicalName
} else {
conf.dialect
}
+ /**
+ * Add a jar to SQLContext
+ */
+ protected[sql] def addJar(path: String): Unit = {
+ sparkContext.addJar(path)
+ }
+
{
// We extract spark sql settings from SparkContext's conf and put them to
// Spark SQL's conf.
@@ -236,9 +243,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
}
- @transient
- protected[sql] val cacheManager = new CacheManager(this)
-
/**
* :: Experimental ::
* A collection of methods that are considered experimental, but can be used to hook into
@@ -300,21 +304,25 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group cachemgmt
* @since 1.3.0
*/
- def isCached(tableName: String): Boolean = cacheManager.isCached(tableName)
+ def isCached(tableName: String): Boolean = {
+ cacheManager.lookupCachedData(table(tableName)).nonEmpty
+ }
/**
* Caches the specified table in-memory.
* @group cachemgmt
* @since 1.3.0
*/
- def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName)
+ def cacheTable(tableName: String): Unit = {
+ cacheManager.cacheQuery(table(tableName), Some(tableName))
+ }
/**
* Removes the specified table from the in-memory cache.
* @group cachemgmt
* @since 1.3.0
*/
- def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName)
+ def uncacheTable(tableName: String): Unit = cacheManager.uncacheQuery(table(tableName))
/**
* Removes all cached tables from the in-memory cache.
@@ -830,36 +838,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
)
}
- protected[sql] def openSession(): SQLSession = {
- detachSession()
- val session = createSession()
- tlSession.set(session)
-
- session
- }
-
- protected[sql] def currentSession(): SQLSession = {
- tlSession.get()
- }
-
- protected[sql] def createSession(): SQLSession = {
- new this.SQLSession()
- }
-
- protected[sql] def detachSession(): Unit = {
- tlSession.remove()
- }
-
- protected[sql] def setSession(session: SQLSession): Unit = {
- detachSession()
- tlSession.set(session)
- }
-
- protected[sql] class SQLSession {
- // Note that this is a lazy val so we can override the default value in subclasses.
- protected[sql] lazy val conf: SQLConf = new SQLConf
- }
-
@deprecated("use org.apache.spark.sql.QueryExecution", "1.6.0")
protected[sql] class QueryExecution(logical: LogicalPlan)
extends sparkexecution.QueryExecution(this, logical)
@@ -1196,46 +1174,90 @@ class SQLContext(@transient val sparkContext: SparkContext)
// Register a successfully instantiated context to the singleton. This should be at the end of
// the class definition so that the singleton is updated only if there is no exception in the
// construction of the instance.
- SQLContext.setLastInstantiatedContext(self)
+ sparkContext.addSparkListener(new SparkListener {
+ override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
+ SQLContext.clearInstantiatedContext(self)
+ }
+ })
+
+ SQLContext.setInstantiatedContext(self)
}
/**
* This SQLContext object contains utility functions to create a singleton SQLContext instance,
- * or to get the last created SQLContext instance.
+ * or to get the created SQLContext instance.
+ *
+ * It also provides utility functions to support a per-thread preferred SQLContext in a
+ * multi-session scenario: setActive sets a SQLContext for the current thread, and it will be
+ * returned by getOrCreate instead of the global one.
*/
object SQLContext {
- private val INSTANTIATION_LOCK = new Object()
+ /**
+ * The active SQLContext for the current thread.
+ */
+ private val activeContext: InheritableThreadLocal[SQLContext] =
+ new InheritableThreadLocal[SQLContext]
/**
- * Reference to the last created SQLContext.
+ * Reference to the created SQLContext.
*/
- @transient private val lastInstantiatedContext = new AtomicReference[SQLContext]()
+ @transient private val instantiatedContext = new AtomicReference[SQLContext]()
/**
* Get the singleton SQLContext if it exists or create a new one using the given SparkContext.
+ *
* This function can be used to create a singleton SQLContext object that can be shared across
* the JVM.
+ *
+ * If there is an active SQLContext for the current thread, it will be returned instead of the
+ * global one.
+ *
+ * @since 1.5.0
*/
def getOrCreate(sparkContext: SparkContext): SQLContext = {
- INSTANTIATION_LOCK.synchronized {
- if (lastInstantiatedContext.get() == null) {
+ val ctx = activeContext.get()
+ if (ctx != null) {
+ return ctx
+ }
+
+ synchronized {
+ val ctx = instantiatedContext.get()
+ if (ctx == null) {
new SQLContext(sparkContext)
+ } else {
+ ctx
}
}
- lastInstantiatedContext.get()
}
- private[sql] def clearLastInstantiatedContext(): Unit = {
- INSTANTIATION_LOCK.synchronized {
- lastInstantiatedContext.set(null)
- }
+ private[sql] def clearInstantiatedContext(sqlContext: SQLContext): Unit = {
+ instantiatedContext.compareAndSet(sqlContext, null)
}
- private[sql] def setLastInstantiatedContext(sqlContext: SQLContext): Unit = {
- INSTANTIATION_LOCK.synchronized {
- lastInstantiatedContext.set(sqlContext)
- }
+ private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = {
+ instantiatedContext.compareAndSet(null, sqlContext)
+ }
+
+ /**
+ * Changes the SQLContext that will be returned in this thread and its children when
+ * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives
+ * a SQLContext with an isolated session, instead of the global (first created) context.
+ *
+ * @since 1.6.0
+ */
+ def setActive(sqlContext: SQLContext): Unit = {
+ activeContext.set(sqlContext)
+ }
+
+ /**
+ * Clears the active SQLContext for the current thread. Subsequent calls to getOrCreate will
+ * return the first created context instead of a thread-local override.
+ *
+ * @since 1.6.0
+ */
+ def clearActive(): Unit = {
+ activeContext.remove()
}
/**
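
For orientation, here is a usage sketch of the new session APIs in SQLContext (illustrative only, not part of the patch; the local-mode SparkConf and the configuration key are placeholder choices):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    val sc = new SparkContext(new SparkConf().setAppName("sessions-demo").setMaster("local[2]"))

    // The first call instantiates and registers the global context.
    val rootCtx = SQLContext.getOrCreate(sc)

    // A new session shares the SparkContext and CacheManager but has its own
    // SQLConf, temporary tables and registered functions.
    val session = rootCtx.newSession()
    session.setConf("spark.sql.shuffle.partitions", "4")

    // Make this thread (and child threads) resolve getOrCreate to the isolated session.
    SQLContext.setActive(session)
    assert(SQLContext.getOrCreate(sc) eq session)

    // Drop the thread-local override; getOrCreate falls back to the first created context.
    SQLContext.clearActive()
    assert(SQLContext.getOrCreate(sc) eq rootCtx)
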
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index d3e5c378d0..f85aeb1b02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql.execution
import java.util.concurrent.locks.ReentrantReadWriteLock
import org.apache.spark.Logging
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
@@ -37,7 +37,7 @@ private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMe
*
* Internal to Spark SQL.
*/
-private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
+private[sql] class CacheManager extends Logging {
@transient
private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
@@ -45,15 +45,6 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
@transient
private val cacheLock = new ReentrantReadWriteLock
- /** Returns true if the table is currently cached in-memory. */
- def isCached(tableName: String): Boolean = lookupCachedData(sqlContext.table(tableName)).nonEmpty
-
- /** Caches the specified table in-memory. */
- def cacheTable(tableName: String): Unit = cacheQuery(sqlContext.table(tableName), Some(tableName))
-
- /** Removes the specified table from the in-memory cache. */
- def uncacheTable(tableName: String): Unit = uncacheQuery(sqlContext.table(tableName))
-
/** Acquires a read lock on the cache for the duration of `f`. */
private def readLock[A](f: => A): A = {
val lock = cacheLock.readLock()
@@ -96,6 +87,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
if (lookupCachedData(planToCache).nonEmpty) {
logWarning("Asked to cache already cached data.")
} else {
+ val sqlContext = query.sqlContext
cachedData +=
CachedData(
planToCache,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index dd88ae3700..1994dacfc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -17,33 +17,52 @@
package org.apache.spark.sql
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.{SharedSparkContext, SparkFunSuite}
-class SQLContextSuite extends SparkFunSuite with SharedSQLContext {
-
- override def afterAll(): Unit = {
- try {
- SQLContext.setLastInstantiatedContext(sqlContext)
- } finally {
- super.afterAll()
- }
- }
+class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
test("getOrCreate instantiates SQLContext") {
- SQLContext.clearLastInstantiatedContext()
- val sqlContext = SQLContext.getOrCreate(sparkContext)
+ val sqlContext = SQLContext.getOrCreate(sc)
assert(sqlContext != null, "SQLContext.getOrCreate returned null")
- assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext),
+ assert(SQLContext.getOrCreate(sc).eq(sqlContext),
"SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate")
}
- test("getOrCreate gets last explicitly instantiated SQLContext") {
- SQLContext.clearLastInstantiatedContext()
- val sqlContext = new SQLContext(sparkContext)
- assert(SQLContext.getOrCreate(sparkContext) != null,
- "SQLContext.getOrCreate after explicitly created SQLContext returned null")
- assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext),
+ test("getOrCreate return the original SQLContext") {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ val newSession = sqlContext.newSession()
+ assert(SQLContext.getOrCreate(sc).eq(sqlContext),
"SQLContext.getOrCreate after explicitly created SQLContext did not return the context")
+ SQLContext.setActive(newSession)
+ assert(SQLContext.getOrCreate(sc).eq(newSession),
+ "SQLContext.getOrCreate after explicitly setActive() did not return the active context")
+ }
+
+ test("Sessions of SQLContext") {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ val session1 = sqlContext.newSession()
+ val session2 = sqlContext.newSession()
+
+ // all have the default configurations
+ val key = SQLConf.SHUFFLE_PARTITIONS.key
+ assert(session1.getConf(key) === session2.getConf(key))
+ session1.setConf(key, "1")
+ session2.setConf(key, "2")
+ assert(session1.getConf(key) === "1")
+ assert(session2.getConf(key) === "2")
+
+ // temporary table should not be shared
+ val df = session1.range(10)
+ df.registerTempTable("test1")
+ assert(session1.tableNames().contains("test1"))
+ assert(!session2.tableNames().contains("test1"))
+
+ // UDF should not be shared
+ def myadd(a: Int, b: Int): Int = a + b
+ session1.udf.register[Int, Int, Int]("myadd", myadd)
+ session1.sql("select myadd(1, 2)").explain()
+ intercept[AnalysisException] {
+ session2.sql("select myadd(1, 2)").explain()
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 10e633f3cd..c89a151650 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -31,23 +31,16 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel
new SparkConf().set("spark.sql.testkey", "true")))
}
- // Make sure we set those test specific confs correctly when we create
- // the SQLConf as well as when we call clear.
- protected[sql] override def createSession(): SQLSession = new this.SQLSession()
+ protected[sql] override lazy val conf: SQLConf = new SQLConf {
- /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */
- protected[sql] class SQLSession extends super.SQLSession {
- protected[sql] override lazy val conf: SQLConf = new SQLConf {
+ clear()
- clear()
+ override def clear(): Unit = {
+ super.clear()
- override def clear(): Unit = {
- super.clear()
-
- // Make sure we start with the default test configs even after clear
- TestSQLContext.overrideConfs.map {
- case (key, value) => setConfString(key, value)
- }
+ // Make sure we start with the default test configs even after clear
+ TestSQLContext.overrideConfs.map {
+ case (key, value) => setConfString(key, value)
}
}
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index 306f98bcb5..719b03e1c7 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -20,19 +20,15 @@ package org.apache.spark.sql.hive.thriftserver
import java.security.PrivilegedExceptionAction
import java.sql.{Date, Timestamp}
import java.util.concurrent.RejectedExecutionException
-import java.util.{Arrays, Map => JMap, UUID}
+import java.util.{Arrays, UUID, Map => JMap}
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, Map => SMap}
import scala.util.control.NonFatal
-import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.api.FieldSchema
-import org.apache.hive.service.cli._
-import org.apache.hadoop.hive.ql.metadata.Hive
-import org.apache.hadoop.hive.ql.metadata.HiveException
-import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.shims.Utils
+import org.apache.hive.service.cli._
import org.apache.hive.service.cli.operation.ExecuteStatementOperation
import org.apache.hive.service.cli.session.HiveSession
@@ -40,7 +36,7 @@ import org.apache.spark.Logging
import org.apache.spark.sql.execution.SetCommand
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf}
+import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow}
private[hive] class SparkExecuteStatementOperation(
@@ -143,30 +139,15 @@ private[hive] class SparkExecuteStatementOperation(
if (!runInBackground) {
runInternal()
} else {
- val parentSessionState = SessionState.get()
- val hiveConf = getConfigForOperation()
val sparkServiceUGI = Utils.getUGI()
- val sessionHive = getCurrentHive()
- val currentSqlSession = hiveContext.currentSession
// Runnable impl to call runInternal asynchronously,
// from a different thread
val backgroundOperation = new Runnable() {
override def run(): Unit = {
- val doAsAction = new PrivilegedExceptionAction[Object]() {
- override def run(): Object = {
-
- // User information is part of the metastore client member in Hive
- hiveContext.setSession(currentSqlSession)
- // Always use the latest class loader provided by executionHive's state.
- val executionHiveClassLoader =
- hiveContext.executionHive.state.getConf.getClassLoader
- sessionHive.getConf.setClassLoader(executionHiveClassLoader)
- parentSessionState.getConf.setClassLoader(executionHiveClassLoader)
-
- Hive.set(sessionHive)
- SessionState.setCurrentSessionState(parentSessionState)
+ val doAsAction = new PrivilegedExceptionAction[Unit]() {
+ override def run(): Unit = {
try {
runInternal()
} catch {
@@ -174,7 +155,6 @@ private[hive] class SparkExecuteStatementOperation(
setOperationException(e)
log.error("Error running hive query: ", e)
}
- return null
}
}
@@ -191,7 +171,7 @@ private[hive] class SparkExecuteStatementOperation(
try {
// This submit blocks if no background threads are available to run this operation
val backgroundHandle =
- getParentSession().getSessionManager().submitBackgroundOperation(backgroundOperation)
+ parentSession.getSessionManager().submitBackgroundOperation(backgroundOperation)
setBackgroundHandle(backgroundHandle)
} catch {
case rejected: RejectedExecutionException =>
@@ -210,6 +190,11 @@ private[hive] class SparkExecuteStatementOperation(
statementId = UUID.randomUUID().toString
logInfo(s"Running query '$statement' with $statementId")
setState(OperationState.RUNNING)
+ // Always use the latest class loader provided by executionHive's state.
+ val executionHiveClassLoader =
+ hiveContext.executionHive.state.getConf.getClassLoader
+ Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
+
HiveThriftServer2.listener.onStatementStart(
statementId,
parentSession.getSessionHandle.getSessionId.toString,
@@ -279,43 +264,4 @@ private[hive] class SparkExecuteStatementOperation(
}
}
}
-
- /**
- * If there are query specific settings to overlay, then create a copy of config
- * There are two cases we need to clone the session config that's being passed to hive driver
- * 1. Async query -
- * If the client changes a config setting, that shouldn't reflect in the execution
- * already underway
- * 2. confOverlay -
- * The query specific settings should only be applied to the query config and not session
- * @return new configuration
- * @throws HiveSQLException
- */
- private def getConfigForOperation(): HiveConf = {
- var sqlOperationConf = getParentSession().getHiveConf()
- if (!getConfOverlay().isEmpty() || runInBackground) {
- // clone the partent session config for this query
- sqlOperationConf = new HiveConf(sqlOperationConf)
-
- // apply overlay query specific settings, if any
- getConfOverlay().asScala.foreach { case (k, v) =>
- try {
- sqlOperationConf.verifyAndSet(k, v)
- } catch {
- case e: IllegalArgumentException =>
- throw new HiveSQLException("Error applying statement specific settings", e)
- }
- }
- }
- return sqlOperationConf
- }
-
- private def getCurrentHive(): Hive = {
- try {
- return Hive.get()
- } catch {
- case e: HiveException =>
- throw new HiveSQLException("Failed to get current Hive object", e);
- }
- }
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
index 92ac0ec3fc..33aaead3fb 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
@@ -36,7 +36,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext:
extends SessionManager(hiveServer)
with ReflectedCompositeService {
- private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext)
+ private lazy val sparkSqlOperationManager = new SparkSQLOperationManager()
override def init(hiveConf: HiveConf) {
setSuperField(this, "hiveConf", hiveConf)
@@ -60,13 +60,15 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext:
sessionConf: java.util.Map[String, String],
withImpersonation: Boolean,
delegationToken: String): SessionHandle = {
- hiveContext.openSession()
val sessionHandle =
super.openSession(protocol, username, passwd, ipAddress, sessionConf, withImpersonation,
delegationToken)
val session = super.getSession(sessionHandle)
HiveThriftServer2.listener.onSessionCreated(
session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername)
+ val ctx = hiveContext.newSession()
+ ctx.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion)
+ sparkSqlOperationManager.sessionToContexts += sessionHandle -> ctx
sessionHandle
}
@@ -74,7 +76,6 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext:
HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString)
super.closeSession(sessionHandle)
sparkSqlOperationManager.sessionToActivePool -= sessionHandle
-
- hiveContext.detachSession()
+ sparkSqlOperationManager.sessionToContexts.remove(sessionHandle)
}
}
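
Since each Thrift server connection now gets its own HiveContext session, session-level settings no longer leak across JDBC connections. A hedged client-side sketch (illustrative only, not part of the patch; the JDBC URL and credentials are placeholders, and the standard Hive JDBC driver is assumed to be on the classpath):

    import java.sql.DriverManager

    Class.forName("org.apache.hive.jdbc.HiveDriver")
    val url = "jdbc:hive2://localhost:10000/default"   // placeholder endpoint

    // Two connections map to two independent HiveContext sessions.
    val c1 = DriverManager.getConnection(url, "user", "")
    val c2 = DriverManager.getConnection(url, "user", "")
    c1.createStatement().execute("SET spark.sql.shuffle.partitions=1")
    c2.createStatement().execute("SET spark.sql.shuffle.partitions=2")
    // A subsequent "SET spark.sql.shuffle.partitions" on each connection
    // reports 1 and 2 respectively, instead of the last value set globally.
    c1.close(); c2.close()
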
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index c8031ed0f3..476651a559 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -30,20 +30,21 @@ import org.apache.spark.sql.hive.thriftserver.{SparkExecuteStatementOperation, R
/**
* Executes queries using Spark SQL, and maintains a list of handles to active queries.
*/
-private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
+private[thriftserver] class SparkSQLOperationManager()
extends OperationManager with Logging {
val handleToOperation = ReflectionUtils
.getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation")
val sessionToActivePool = Map[SessionHandle, String]()
+ val sessionToContexts = Map[SessionHandle, HiveContext]()
override def newExecuteStatementOperation(
parentSession: HiveSession,
statement: String,
confOverlay: JMap[String, String],
async: Boolean): ExecuteStatementOperation = synchronized {
-
+ val hiveContext = sessionToContexts(parentSession.getSessionHandle)
val runInBackground = async && hiveContext.hiveThriftServerAsync
val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay,
runInBackground)(hiveContext, sessionToActivePool)
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index e59a14ec00..76d1591a23 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -96,7 +96,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging {
buffer += s"${new Timestamp(new Date().getTime)} - $source> $line"
// If we haven't found all expected answers and another expected answer comes up...
- if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) {
+ if (next < expectedAnswers.size && line.contains(expectedAnswers(next))) {
next += 1
// If all expected answers have been found...
if (next == expectedAnswers.size) {
@@ -159,7 +159,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging {
s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE hive_test;"
-> "OK",
"CACHE TABLE hive_test;"
- -> "Time taken: ",
+ -> "",
"SELECT COUNT(*) FROM hive_test;"
-> "5",
"DROP TABLE hive_test;"
@@ -180,7 +180,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging {
"CREATE TABLE hive_test(key INT, val STRING);"
-> "OK",
"SHOW TABLES;"
- -> "Time taken: "
+ -> "hive_test"
)
runCliWithin(2.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))(
@@ -210,7 +210,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging {
s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;"
-> "OK",
"INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;"
- -> "Time taken:",
+ -> "",
"SELECT count(key) FROM t1;"
-> "5",
"DROP TABLE t1;"
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 19b2f24456..ff8ca01506 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -205,6 +205,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
import org.apache.spark.sql.SQLConf
var defaultV1: String = null
var defaultV2: String = null
+ var data: ArrayBuffer[Int] = null
withMultipleConnectionJdbcStatement(
// create table
@@ -214,10 +215,16 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
"DROP TABLE IF EXISTS test_map",
"CREATE TABLE test_map(key INT, value STRING)",
s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map",
- "CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC")
+ "CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC",
+ "CREATE DATABASE db1")
queries.foreach(statement.execute)
+ val plan = statement.executeQuery("explain select * from test_table")
+ plan.next()
+ plan.next()
+ assert(plan.getString(1).contains("InMemoryColumnarTableScan"))
+
val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC")
val buf1 = new collection.mutable.ArrayBuffer[Int]()
while (rs1.next()) {
@@ -233,6 +240,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
rs2.close()
assert(buf1 === buf2)
+
+ data = buf1
},
// first session, we get the default value of the session status
@@ -289,56 +298,51 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
rs2.close()
},
- // accessing the cached data in another session
+ // try to access the cached data in another session
{ statement =>
- val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC")
- val buf1 = new collection.mutable.ArrayBuffer[Int]()
- while (rs1.next()) {
- buf1 += rs1.getInt(1)
+ // Cached temporary table can't be accessed by other sessions
+ intercept[SQLException] {
+ statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC")
}
- rs1.close()
- val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
- val buf2 = new collection.mutable.ArrayBuffer[Int]()
- while (rs2.next()) {
- buf2 += rs2.getInt(1)
+ val plan = statement.executeQuery("explain select key from test_map ORDER BY key DESC")
+ plan.next()
+ plan.next()
+ assert(plan.getString(1).contains("InMemoryColumnarTableScan"))
+
+ val rs = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
+ val buf = new collection.mutable.ArrayBuffer[Int]()
+ while (rs.next()) {
+ buf += rs.getInt(1)
}
- rs2.close()
+ rs.close()
+ assert(buf === data)
+ },
- assert(buf1 === buf2)
- statement.executeQuery("UNCACHE TABLE test_table")
+ // switch to another database
+ { statement =>
+ statement.execute("USE db1")
- // TODO need to figure out how to determine if the data loaded from cache
- val rs3 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
- val buf3 = new collection.mutable.ArrayBuffer[Int]()
- while (rs3.next()) {
- buf3 += rs3.getInt(1)
+ // there is no test_map table in db1
+ intercept[SQLException] {
+ statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
}
- rs3.close()
- assert(buf1 === buf3)
+ statement.execute("CREATE TABLE test_map2(key INT, value STRING)")
},
- // accessing the uncached table
+ // access default database
{ statement =>
- // TODO need to figure out how to determine if the data loaded from cache
- val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC")
- val buf1 = new collection.mutable.ArrayBuffer[Int]()
- while (rs1.next()) {
- buf1 += rs1.getInt(1)
- }
- rs1.close()
-
- val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
- val buf2 = new collection.mutable.ArrayBuffer[Int]()
- while (rs2.next()) {
- buf2 += rs2.getInt(1)
+ // current database should still be `default`
+ intercept[SQLException] {
+ statement.executeQuery("SELECT key FROM test_map2")
}
- rs2.close()
- assert(buf1 === buf2)
+ statement.execute("USE db1")
+ // access test_map2
+ statement.executeQuery("SELECT key from test_map2")
}
)
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 17de8ef56f..dad1e2347c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -25,7 +25,6 @@ import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import scala.collection.mutable.HashMap
import scala.language.implicitConversions
-import scala.concurrent.duration._
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.common.StatsSetupConst
@@ -34,32 +33,49 @@ import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hadoop.hive.ql.metadata.Table
import org.apache.hadoop.hive.ql.parse.VariableSubstitution
-import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable}
-import org.apache.spark.Logging
-import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql._
+import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.SQLConf.SQLConfEntry._
-import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier, ParserDialect}
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand}
-import org.apache.spark.sql.execution.datasources.{PreWriteCheck, PreInsertCastAndRename, DataSourceStrategy}
+import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser}
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PreInsertCastAndRename, PreWriteCheck}
+import org.apache.spark.sql.execution.{CacheManager, ExecutedCommand, ExtractPythonUDFs, SetCommand}
import org.apache.spark.sql.hive.client._
import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkContext}
/**
* This is the HiveQL Dialect; it is strongly bound to HiveContext
*/
-private[hive] class HiveQLDialect extends ParserDialect {
+private[hive] class HiveQLDialect(sqlContext: HiveContext) extends ParserDialect {
override def parse(sqlText: String): LogicalPlan = {
- HiveQl.parseSql(sqlText)
+ sqlContext.executionHive.withHiveState {
+ HiveQl.parseSql(sqlText)
+ }
+ }
+}
+
+/**
+ * Returns the current database of metadataHive.
+ */
+private[hive] case class CurrentDatabase(ctx: HiveContext)
+ extends LeafExpression with CodegenFallback {
+ override def dataType: DataType = StringType
+ override def foldable: Boolean = true
+ override def nullable: Boolean = false
+ override def eval(input: InternalRow): Any = {
+ UTF8String.fromString(ctx.metadataHive.currentDatabase)
}
}
@@ -69,14 +85,30 @@ private[hive] class HiveQLDialect extends ParserDialect {
*
* @since 1.0.0
*/
-class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
+class HiveContext private[hive](
+ sc: SparkContext,
+ cacheManager: CacheManager,
+ @transient execHive: ClientWrapper,
+ @transient metaHive: ClientInterface) extends SQLContext(sc, cacheManager) with Logging {
self =>
- import HiveContext._
+ def this(sc: SparkContext) = this(sc, new CacheManager, null, null)
+ def this(sc: JavaSparkContext) = this(sc.sc)
+
+ import org.apache.spark.sql.hive.HiveContext._
logDebug("create HiveContext")
/**
+ * Returns a new HiveContext as a new session, which has separate SQLConf, UDFs/UDAFs,
+ * temporary tables and SessionState, but shares the same CacheManager, IsolatedClientLoader
+ * and Hive clients (both execution and metadata) with the existing HiveContext.
+ */
+ override def newSession(): HiveContext = {
+ new HiveContext(sc, cacheManager, executionHive.newSession(), metadataHive.newSession())
+ }
+
+ /**
* When true, enables an experimental feature where metastore tables that use the parquet SerDe
* are automatically converted to use the Spark SQL parquet table scan, instead of the Hive
* SerDe.
@@ -157,14 +189,18 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
* for storing persistent metadata, and only point to a dummy metastore in a temporary directory.
*/
@transient
- protected[hive] lazy val executionHive: ClientWrapper = {
+ protected[hive] lazy val executionHive: ClientWrapper = if (execHive != null) {
+ execHive
+ } else {
logInfo(s"Initializing execution hive, version $hiveExecutionVersion")
- new ClientWrapper(
+ val loader = new IsolatedClientLoader(
version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion),
+ execJars = Seq(),
config = newTemporaryConfiguration(),
- initClassLoader = Utils.getContextOrSparkClassLoader)
+ isolationOn = false,
+ baseClassLoader = Utils.getContextOrSparkClassLoader)
+ loader.createClient().asInstanceOf[ClientWrapper]
}
- SessionState.setCurrentSessionState(executionHive.state)
/**
* Overrides default Hive configurations to avoid breaking changes to Spark SQL users.
@@ -182,7 +218,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
* in the hive-site.xml file.
*/
@transient
- protected[hive] lazy val metadataHive: ClientInterface = {
+ protected[hive] lazy val metadataHive: ClientInterface = if (metaHive != null) {
+ metaHive
+ } else {
val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion)
// We instantiate a HiveConf here to read in the hive-site.xml file and then pass the options
@@ -268,14 +306,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
barrierPrefixes = hiveMetastoreBarrierPrefixes,
sharedPrefixes = hiveMetastoreSharedPrefixes)
}
- isolatedLoader.client
+ isolatedLoader.createClient()
}
protected[sql] override def parseSql(sql: String): LogicalPlan = {
- var state = SessionState.get()
- if (state == null) {
- SessionState.setCurrentSessionState(tlSession.get().asInstanceOf[SQLSession].sessionState)
- }
super.parseSql(substitutor.substitute(hiveconf, sql))
}
@@ -384,8 +418,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
}
}
- protected[hive] def hiveconf = tlSession.get().asInstanceOf[this.SQLSession].hiveconf
-
override def setConf(key: String, value: String): Unit = {
super.setConf(key, value)
executionHive.runSqlHive(s"SET $key=$value")
@@ -402,7 +434,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
setConf(entry.key, entry.stringConverter(value))
}
- /* A catalyst metadata catalog that points to the Hive Metastore. */
+ /* A catalyst metadata catalog that points to the Hive Metastore. */
@transient
override protected[sql] lazy val catalog =
new HiveMetastoreCatalog(metadataHive, this) with OverrideCatalog
@@ -410,7 +442,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
// Note that HiveUDFs will be overridden by functions registered in this context.
@transient
override protected[sql] lazy val functionRegistry: FunctionRegistry =
- new HiveFunctionRegistry(FunctionRegistry.builtin)
+ new HiveFunctionRegistry(FunctionRegistry.builtin.copy())
+
+ // The Hive UDF current_database() is foldable and will be evaluated by the optimizer,
+ // but the optimizer can't access the SessionState of metadataHive.
+ functionRegistry.registerFunction(
+ "current_database",
+ (expressions: Seq[Expression]) => new CurrentDatabase(this))
/* An analyzer that uses the Hive metastore. */
@transient
@@ -430,10 +468,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
)
}
- override protected[sql] def createSession(): SQLSession = {
- new this.SQLSession()
- }
-
/** Overridden by child classes that need to set configuration before the client init. */
protected def configure(): Map[String, String] = {
// Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch
@@ -488,41 +522,40 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
}.toMap
}
- protected[hive] class SQLSession extends super.SQLSession {
- protected[sql] override lazy val conf: SQLConf = new SQLConf {
- override def dialect: String = getConf(SQLConf.DIALECT, "hiveql")
- override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
- }
-
- /**
- * SQLConf and HiveConf contracts:
- *
- * 1. reuse existing started SessionState if any
- * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the
- * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be
- * set in the SQLConf *as well as* in the HiveConf.
- */
- protected[hive] lazy val sessionState: SessionState = {
- var state = SessionState.get()
- if (state == null) {
- state = new SessionState(new HiveConf(classOf[SessionState]))
- SessionState.start(state)
- }
- state
- }
+ /**
+ * SQLConf and HiveConf contracts:
+ *
+ * 1. create a new SessionState for each HiveContext
+ * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the
+ * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be
+ * set in the SQLConf *as well as* in the HiveConf.
+ */
+ @transient
+ protected[hive] lazy val hiveconf: HiveConf = {
+ val c = executionHive.conf
+ setConf(c.getAllProperties)
+ c
+ }
- protected[hive] lazy val hiveconf: HiveConf = {
- setConf(sessionState.getConf.getAllProperties)
- sessionState.getConf
- }
+ protected[sql] override lazy val conf: SQLConf = new SQLConf {
+ override def dialect: String = getConf(SQLConf.DIALECT, "hiveql")
+ override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
}
- override protected[sql] def dialectClassName = if (conf.dialect == "hiveql") {
+ protected[sql] override def dialectClassName = if (conf.dialect == "hiveql") {
classOf[HiveQLDialect].getCanonicalName
} else {
super.dialectClassName
}
+ protected[sql] override def getSQLDialect(): ParserDialect = {
+ if (conf.dialect == "hiveql") {
+ new HiveQLDialect(this)
+ } else {
+ super.getSQLDialect()
+ }
+ }
+
@transient
private val hivePlanner = new SparkPlanner with HiveStrategies {
val hiveContext = self
@@ -598,6 +631,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
case _ => super.simpleString
}
}
+
+ protected[sql] override def addJar(path: String): Unit = {
+ // Add jar to Hive and classloader
+ executionHive.addJar(path)
+ metadataHive.addJar(path)
+ Thread.currentThread().setContextClassLoader(executionHive.clientLoader.classLoader)
+ super.addJar(path)
+ }
}
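
To make the HiveContext session semantics above concrete, a usage sketch (illustrative only, not part of the patch; demo_tbl and the local SparkContext are placeholders, and a working Hive/metastore setup is assumed):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.hive.HiveContext

    val sc = new SparkContext(new SparkConf().setAppName("hive-sessions-demo").setMaster("local[2]"))
    val hive = new HiveContext(sc)

    val s1 = hive.newSession()
    val s2 = hive.newSession()

    // Temporary tables and SQLConf stay per-session ...
    s1.range(0, 5).registerTempTable("tmp")
    assert(!s2.tableNames().contains("tmp"))

    // ... while both sessions share the same metastore client, so persistent
    // tables created in one session are visible to the other.
    s1.sql("CREATE TABLE IF NOT EXISTS demo_tbl(key INT)")
    assert(s2.tableNames().contains("demo_tbl"))
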
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 2bf22f5449..250c232856 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -25,29 +25,27 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hadoop.hive.serde.serdeConstants
-import org.apache.hadoop.hive.ql.{ErrorMsg, Context}
-import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
+import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
import org.apache.hadoop.hive.ql.lib.Node
import org.apache.hadoop.hive.ql.parse._
import org.apache.hadoop.hive.ql.plan.PlanUtils
import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.ql.{Context, ErrorMsg}
+import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.spark.Logging
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.{AnalysisException, catalyst}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.{logical, _}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.execution.ExplainCommand
import org.apache.spark.sql.execution.datasources.DescribeCommand
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.hive.client._
-import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema}
+import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.random.RandomSampler
@@ -268,7 +266,7 @@ private[hive] object HiveQl extends Logging {
node
}
- private def createContext(): Context = new Context(SessionState.get().getConf())
+ private def createContext(): Context = new Context(hiveConf)
private def getAst(sql: String, context: Context) =
ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, context))
@@ -277,12 +275,16 @@ private[hive] object HiveQl extends Logging {
* Returns the HiveConf
*/
private[this] def hiveConf: HiveConf = {
- val ss = SessionState.get() // SessionState is lazy initialization, it can be null here
+ var ss = SessionState.get()
+ // SessionState is lazily initialized, so it can be null here
if (ss == null) {
- new HiveConf()
- } else {
- ss.getConf
+ val original = Thread.currentThread().getContextClassLoader
+ val conf = new HiveConf(classOf[SessionState])
+ conf.setClassLoader(original)
+ ss = new SessionState(conf)
+ SessionState.start(ss)
}
+ ss.getConf
}
/** Returns a LogicalPlan for a given HiveQL string. */
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
index 915eae9d21..9d9a55edd7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
@@ -178,6 +178,15 @@ private[hive] trait ClientInterface {
holdDDLTime: Boolean,
listBucketingEnabled: Boolean): Unit
+ /** Add a jar into the class loader */
+ def addJar(path: String): Unit
+
+ /** Return a ClientInterface as a new session that shares the class loader and Hive client */
+ def newSession(): ClientInterface
+
+ /** Run a function within Hive state (SessionState, HiveConf, Hive client and class loader) */
+ def withHiveState[A](f: => A): A
+
/** Used for testing only. Removes all metadata from this instance of Hive. */
def reset(): Unit
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index 8f6d448b2a..3dce86c480 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -60,7 +60,8 @@ import org.apache.spark.util.{CircularBuffer, Utils}
private[hive] class ClientWrapper(
override val version: HiveVersion,
config: Map[String, String],
- initClassLoader: ClassLoader)
+ initClassLoader: ClassLoader,
+ val clientLoader: IsolatedClientLoader)
extends ClientInterface
with Logging {
@@ -150,31 +151,29 @@ private[hive] class ClientWrapper(
// Switch to the initClassLoader.
Thread.currentThread().setContextClassLoader(initClassLoader)
val ret = try {
- val oldState = SessionState.get()
- if (oldState == null) {
- val initialConf = new HiveConf(classOf[SessionState])
- // HiveConf is a Hadoop Configuration, which has a field of classLoader and
- // the initial value will be the current thread's context class loader
- // (i.e. initClassLoader at here).
- // We call initialConf.setClassLoader(initClassLoader) at here to make
- // this action explicit.
- initialConf.setClassLoader(initClassLoader)
- config.foreach { case (k, v) =>
- if (k.toLowerCase.contains("password")) {
- logDebug(s"Hive Config: $k=xxx")
- } else {
- logDebug(s"Hive Config: $k=$v")
- }
- initialConf.set(k, v)
+ val initialConf = new HiveConf(classOf[SessionState])
+ // HiveConf is a Hadoop Configuration, which has a field of classLoader and
+ // the initial value will be the current thread's context class loader
+ // (i.e. initClassLoader here).
+ // We call initialConf.setClassLoader(initClassLoader) here to make
+ // this action explicit.
+ initialConf.setClassLoader(initClassLoader)
+ config.foreach { case (k, v) =>
+ if (k.toLowerCase.contains("password")) {
+ logDebug(s"Hive Config: $k=xxx")
+ } else {
+ logDebug(s"Hive Config: $k=$v")
}
- val newState = new SessionState(initialConf)
- SessionState.start(newState)
- newState.out = new PrintStream(outputBuffer, true, "UTF-8")
- newState.err = new PrintStream(outputBuffer, true, "UTF-8")
- newState
- } else {
- oldState
+ initialConf.set(k, v)
+ }
+ val state = new SessionState(initialConf)
+ if (clientLoader.cachedHive != null) {
+ Hive.set(clientLoader.cachedHive.asInstanceOf[Hive])
}
+ SessionState.start(state)
+ state.out = new PrintStream(outputBuffer, true, "UTF-8")
+ state.err = new PrintStream(outputBuffer, true, "UTF-8")
+ state
} finally {
Thread.currentThread().setContextClassLoader(original)
}
@@ -188,11 +187,6 @@ private[hive] class ClientWrapper(
conf.get(key, defaultValue)
}
- // TODO: should be a def?s
- // When we create this val client, the HiveConf of it (conf) is the one associated with state.
- @GuardedBy("this")
- private var client = Hive.get(conf)
-
// We use hive's conf for compatibility.
private val retryLimit = conf.getIntVar(HiveConf.ConfVars.METASTORETHRIFTFAILURERETRIES)
private val retryDelayMillis = shim.getMetastoreClientConnectRetryDelayMillis(conf)
@@ -200,7 +194,7 @@ private[hive] class ClientWrapper(
/**
* Runs `f` with multiple retries in case the hive metastore is temporarily unreachable.
*/
- private def retryLocked[A](f: => A): A = synchronized {
+ private def retryLocked[A](f: => A): A = clientLoader.synchronized {
// Hive sometimes retries internally, so set a deadline to avoid compounding delays.
val deadline = System.nanoTime + (retryLimit * retryDelayMillis * 1e6).toLong
var numTries = 0
@@ -215,13 +209,8 @@ private[hive] class ClientWrapper(
logWarning(
"HiveClientWrapper got thrift exception, destroying client and retrying " +
s"(${retryLimit - numTries} tries remaining)", e)
+ clientLoader.cachedHive = null
Thread.sleep(retryDelayMillis)
- try {
- client = Hive.get(state.getConf, true)
- } catch {
- case e: Exception if causedByThrift(e) =>
- logWarning("Failed to refresh hive client, will retry.", e)
- }
}
} while (numTries <= retryLimit && System.nanoTime < deadline)
if (System.nanoTime > deadline) {
@@ -242,13 +231,26 @@ private[hive] class ClientWrapper(
false
}
+ def client: Hive = {
+ if (clientLoader.cachedHive != null) {
+ clientLoader.cachedHive.asInstanceOf[Hive]
+ } else {
+ val c = Hive.get(conf)
+ clientLoader.cachedHive = c
+ c
+ }
+ }
+
/**
* Runs `f` with ThreadLocal session state and classloaders configured for this version of hive.
*/
- private def withHiveState[A](f: => A): A = retryLocked {
+ def withHiveState[A](f: => A): A = retryLocked {
val original = Thread.currentThread().getContextClassLoader
// Set the thread local metastore client to the client associated with this ClientWrapper.
Hive.set(client)
+ // The classloader in clientLoader could have been changed by addJar, so always
+ // use the latest classloader.
+ state.getConf.setClassLoader(clientLoader.classLoader)
// setCurrentSessionState will use the classLoader associated
// with the HiveConf in `state` to override the context class loader of the current
// thread.
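
The new `client` getter is a get-or-create cache whose value lives on the shared IsolatedClientLoader, so all sessions created from that loader reuse one Hive instance and the retry path can invalidate it for everyone by nulling cachedHive. A generic sketch of that shape (not the patch's types; `create` stands in for Hive.get(conf)):

class SharedCache[T <: AnyRef](create: () => T) {
  private var cached: Option[T] = None
  // Return the cached value, building and remembering it on first use.
  def get: T = synchronized {
    cached.getOrElse {
      val value = create()
      cached = Some(value)
      value
    }
  }
  // Drop the cached value, e.g. after a thrift failure, so the next get() rebuilds it.
  def invalidate(): Unit = synchronized { cached = None }
}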
@@ -545,6 +547,15 @@ private[hive] class ClientWrapper(
listBucketingEnabled)
}
+ def addJar(path: String): Unit = {
+ clientLoader.addJar(path)
+ runSqlHive(s"ADD JAR $path")
+ }
+
+ def newSession(): ClientWrapper = {
+ clientLoader.createClient().asInstanceOf[ClientWrapper]
+ }
+
def reset(): Unit = withHiveState {
client.getAllTables("default").asScala.foreach { t =>
logDebug(s"Deleting table $t")
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
index 1fe4cba957..567e4d7b41 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
@@ -22,6 +22,7 @@ import java.lang.reflect.InvocationTargetException
import java.net.{URL, URLClassLoader}
import java.util
+import scala.collection.mutable
import scala.language.reflectiveCalls
import scala.util.Try
@@ -148,53 +149,75 @@ private[hive] class IsolatedClientLoader(
name.replaceAll("\\.", "/") + ".class"
/** The classloader that is used to load an isolated version of Hive. */
- protected val classLoader: ClassLoader = new URLClassLoader(allJars, rootClassLoader) {
- override def loadClass(name: String, resolve: Boolean): Class[_] = {
- val loaded = findLoadedClass(name)
- if (loaded == null) doLoadClass(name, resolve) else loaded
- }
-
- def doLoadClass(name: String, resolve: Boolean): Class[_] = {
- val classFileName = name.replaceAll("\\.", "/") + ".class"
- if (isBarrierClass(name) && isolationOn) {
- // For barrier classes, we construct a new copy of the class.
- val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName))
- logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}")
- defineClass(name, bytes, 0, bytes.length)
- } else if (!isSharedClass(name)) {
- logDebug(s"hive class: $name - ${getResource(classToPath(name))}")
- super.loadClass(name, resolve)
- } else {
- // For shared classes, we delegate to baseClassLoader.
- logDebug(s"shared class: $name")
- baseClassLoader.loadClass(name)
+ private[hive] var classLoader: ClassLoader = if (isolationOn) {
+ new URLClassLoader(allJars, rootClassLoader) {
+ override def loadClass(name: String, resolve: Boolean): Class[_] = {
+ val loaded = findLoadedClass(name)
+ if (loaded == null) doLoadClass(name, resolve) else loaded
+ }
+ def doLoadClass(name: String, resolve: Boolean): Class[_] = {
+ val classFileName = name.replaceAll("\\.", "/") + ".class"
+ if (isBarrierClass(name)) {
+ // For barrier classes, we construct a new copy of the class.
+ val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName))
+ logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}")
+ defineClass(name, bytes, 0, bytes.length)
+ } else if (!isSharedClass(name)) {
+ logDebug(s"hive class: $name - ${getResource(classToPath(name))}")
+ super.loadClass(name, resolve)
+ } else {
+ // For shared classes, we delegate to baseClassLoader.
+ logDebug(s"shared class: $name")
+ baseClassLoader.loadClass(name)
+ }
}
}
+ } else {
+ baseClassLoader
}
- // Pre-reflective instantiation setup.
- logDebug("Initializing the logger to avoid disaster...")
- Thread.currentThread.setContextClassLoader(classLoader)
+ private[hive] def addJar(path: String): Unit = synchronized {
+ val jarURL = new java.io.File(path).toURI.toURL
+ // TODO: we should avoid stacking classloaders (use a single URLClassLoader and add
+ // jars to it)
+ classLoader = new java.net.URLClassLoader(Array(jarURL), classLoader)
+ }
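
One possible way to resolve the TODO above (a sketch only, not what the patch does): keep a single URLClassLoader for the lifetime of the loader and expose the protected addURL, so added jars extend the existing loader instead of stacking a new loader per ADD JAR:

import java.io.File
import java.net.{URL, URLClassLoader}

class AppendableClassLoader(initial: Array[URL], parent: ClassLoader)
  extends URLClassLoader(initial, parent) {
  // URLClassLoader.addURL is protected; expose it so jars can be appended in place.
  def addJar(path: String): Unit = addURL(new File(path).toURI.toURL)
}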
/** The isolated client interface to Hive. */
- val client: ClientInterface = try {
- classLoader
- .loadClass(classOf[ClientWrapper].getName)
- .getConstructors.head
- .newInstance(version, config, classLoader)
- .asInstanceOf[ClientInterface]
- } catch {
- case e: InvocationTargetException =>
- if (e.getCause().isInstanceOf[NoClassDefFoundError]) {
- val cnf = e.getCause().asInstanceOf[NoClassDefFoundError]
- throw new ClassNotFoundException(
- s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" +
- "Please make sure that jars for your version of hive and hadoop are included in the " +
- s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.")
- } else {
- throw e
- }
- } finally {
- Thread.currentThread.setContextClassLoader(baseClassLoader)
+ private[hive] def createClient(): ClientInterface = {
+ if (!isolationOn) {
+ return new ClientWrapper(version, config, baseClassLoader, this)
+ }
+ // Pre-reflective instantiation setup.
+ logDebug("Initializing the logger to avoid disaster...")
+ val origLoader = Thread.currentThread().getContextClassLoader
+ Thread.currentThread.setContextClassLoader(classLoader)
+
+ try {
+ classLoader
+ .loadClass(classOf[ClientWrapper].getName)
+ .getConstructors.head
+ .newInstance(version, config, classLoader, this)
+ .asInstanceOf[ClientInterface]
+ } catch {
+ case e: InvocationTargetException =>
+ if (e.getCause().isInstanceOf[NoClassDefFoundError]) {
+ val cnf = e.getCause().asInstanceOf[NoClassDefFoundError]
+ throw new ClassNotFoundException(
+ s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" +
+ "Please make sure that jars for your version of hive and hadoop are included in the " +
+ s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.")
+ } else {
+ throw e
+ }
+ } finally {
+ Thread.currentThread.setContextClassLoader(origLoader)
+ }
}
+
+ /**
+ * The placeholder for the shared Hive client used by all HiveContext sessions (they share
+ * an IsolatedClientLoader).
+ */
+ private[hive] var cachedHive: Any = null
}
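
createClient constructs the ClientWrapper reflectively so that the class itself is loaded by the isolated loader rather than the caller's. Stripped of the Hive specifics, the pattern looks like this (illustrative names; the target class is assumed to have a single usable constructor):

object ReflectiveFactory {
  // Load `className` through `loader` and invoke its first declared constructor with `args`.
  def instantiate(loader: ClassLoader, className: String, args: AnyRef*): AnyRef = {
    val clazz = loader.loadClass(className)
    clazz.getConstructors.head.newInstance(args: _*).asInstanceOf[AnyRef]
  }
}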
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 9f654eed57..51ec92afd0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -18,18 +18,18 @@
package org.apache.spark.sql.hive.execution
import org.apache.hadoop.hive.metastore.MetaStoreUtils
+
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{TableIdentifier, SqlParser}
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.RunnableCommand
-import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
/**
* Analyzes the given table in the current database to generate statistics, which will be
@@ -86,26 +86,7 @@ case class AddJar(path: String) extends RunnableCommand {
}
override def run(sqlContext: SQLContext): Seq[Row] = {
- val hiveContext = sqlContext.asInstanceOf[HiveContext]
- val currentClassLoader = Utils.getContextOrSparkClassLoader
-
- // Add jar to current context
- val jarURL = new java.io.File(path).toURI.toURL
- val newClassLoader = new java.net.URLClassLoader(Array(jarURL), currentClassLoader)
- Thread.currentThread.setContextClassLoader(newClassLoader)
- // We need to explicitly set the class loader associated with the conf in executionHive's
- // state because this class loader will be used as the context class loader of the current
- // thread to execute any Hive command.
- // We cannot use `org.apache.hadoop.hive.ql.metadata.Hive.get().getConf()` because Hive.get()
- // returns the value of a thread local variable and its HiveConf may not be the HiveConf
- // associated with `executionHive.state` (for example, HiveContext is created in one thread
- // and then add jar is called from another thread).
- hiveContext.executionHive.state.getConf.setClassLoader(newClassLoader)
- // Add jar to isolated hive (metadataHive) class loader.
- hiveContext.runSqlHive(s"ADD JAR $path")
-
- // Add jar to executors
- hiveContext.sparkContext.addJar(path)
+ sqlContext.addJar(path)
Seq(Row(0))
}
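
After this change both entry points funnel into the same context-level hook. A hedged usage sketch (the hiveContext value and jar path are assumed, and addJar is the SQLContext method this patch introduces elsewhere in the changeset):

import org.apache.spark.sql.hive.HiveContext

def registerUdfJar(hiveContext: HiveContext, jarPath: String): Unit = {
  // Parsed into the AddJar command above, which now simply delegates.
  hiveContext.sql(s"ADD JAR $jarPath")
  // Equivalent direct call, assuming the addJar hook added by this patch:
  // hiveContext.addJar(jarPath)
}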
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index be335a47dc..ff39ccb7c1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -116,27 +116,18 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
override def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution(plan)
- // Make sure we set those test specific confs correctly when we create
- // the SQLConf as well as when we call clear.
- override protected[sql] def createSession(): SQLSession = {
- new this.SQLSession()
- }
-
- protected[hive] class SQLSession extends super.SQLSession {
- protected[sql] override lazy val conf: SQLConf = new SQLConf {
- // TODO as in unit test, conf.clear() probably be called, all of the value will be cleared.
- // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql"
- override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql")
- override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
+ protected[sql] override lazy val conf: SQLConf = new SQLConf {
+ // The super.getConf(SQLConf.DIALECT) is "sql" by default; we need to set it to "hiveql".
+ override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql")
+ override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
- clear()
+ clear()
- override def clear(): Unit = {
- super.clear()
+ override def clear(): Unit = {
+ super.clear()
- TestHiveContext.overrideConfs.map {
- case (key, value) => setConfString(key, value)
- }
+ TestHiveContext.overrideConfs.map {
+ case (key, value) => setConfString(key, value)
}
}
}
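
The TestHive change above collapses the per-session conf into a single conf whose clear() re-applies the test overrides. A generic sketch of that "sticky overrides" idea, deliberately independent of SQLConf:

import scala.collection.mutable

// A config map whose clear() always restores a fixed set of override entries,
// mirroring how TestHiveContext.overrideConfs is re-applied above.
class OverridableConf(overrides: Map[String, String]) {
  private val settings = mutable.Map.empty[String, String]
  clear()

  def set(key: String, value: String): Unit = settings(key) = value
  def get(key: String): Option[String] = settings.get(key)

  def clear(): Unit = {
    settings.clear()
    overrides.foreach { case (k, v) => settings(k) = v }
  }
}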
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
index 79cf40aba4..528a7398b1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
@@ -17,22 +17,15 @@
package org.apache.spark.sql.hive
-import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde.serdeConstants
+import org.scalatest.BeforeAndAfterAll
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.hive.client.{ManagedTable, HiveColumn, ExternalTable, HiveTable}
-import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, ManagedTable}
class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
- override def beforeAll() {
- if (SessionState.get() == null) {
- SessionState.start(new HiveConf())
- }
- }
-
private def extractTableDesc(sql: String): (HiveTable, Boolean) = {
HiveQl.createPlan(sql).collect {
case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index 2da22ec237..c6d034a23a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -53,7 +53,7 @@ class VersionsSuite extends SparkFunSuite with Logging {
test("success sanity check") {
val badClient = IsolatedClientLoader.forVersion(HiveContext.hiveExecutionVersion,
buildConf(),
- ivyPath).client
+ ivyPath).createClient()
val db = new HiveDatabase("default", "")
badClient.createDatabase(db)
}
@@ -83,7 +83,7 @@ class VersionsSuite extends SparkFunSuite with Logging {
ignore("failure sanity check") {
val e = intercept[Throwable] {
val badClient = quietly {
- IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client
+ IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).createClient()
}
}
assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'")
@@ -97,7 +97,7 @@ class VersionsSuite extends SparkFunSuite with Logging {
test(s"$version: create client") {
client = null
System.gc() // Hack to avoid SEGV on some JVM versions.
- client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).client
+ client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).createClient()
}
test(s"$version: createDatabase") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index fe63ad5683..2878500453 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -1133,6 +1133,38 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
conf.clear()
}
+ test("current_database with multiple sessions") {
+ sql("create database a")
+ sql("use a")
+ val s2 = newSession()
+ s2.sql("create database b")
+ s2.sql("use b")
+
+ assert(sql("select current_database()").first() === Row("a"))
+ assert(s2.sql("select current_database()").first() === Row("b"))
+
+ try {
+ sql("create table test_a(key INT, value STRING)")
+ s2.sql("create table test_b(key INT, value STRING)")
+
+ sql("select * from test_a")
+ intercept[AnalysisException] {
+ sql("select * from test_b")
+ }
+ sql("select * from b.test_b")
+
+ s2.sql("select * from test_b")
+ intercept[AnalysisException] {
+ s2.sql("select * from test_a")
+ }
+ s2.sql("select * from a.test_a")
+ } finally {
+ sql("DROP TABLE IF EXISTS test_a")
+ s2.sql("DROP TABLE IF EXISTS test_b")
+ }
+ }
+
createQueryTest("select from thrift based table",
"SELECT * from src_thrift")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index ec5b83b98e..ccc15eaa63 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -160,10 +160,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
test("show functions") {
- val allFunctions =
+ val allBuiltinFunctions =
(FunctionRegistry.builtin.listFunction().toSet[String] ++
org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames.asScala).toList.sorted
- checkAnswer(sql("SHOW functions"), allFunctions.map(Row(_)))
+ // The TestContext is shared by all the test cases; some functions may have been registered
+ // before this, so we only check that all the built-in functions are returned.
+ val allFunctions = sql("SHOW functions").collect().map(r => r(0))
+ allBuiltinFunctions.foreach { f =>
+ assert(allFunctions.contains(f))
+ }
checkAnswer(sql("SHOW functions abs"), Row("abs"))
checkAnswer(sql("SHOW functions 'abs'"), Row("abs"))
checkAnswer(sql("SHOW functions abc.abs"), Row("abs"))