aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-10-08 17:34:24 -0700
committerDavies Liu <davies.liu@gmail.com>2015-10-08 17:34:24 -0700
commit3390b400d04e40f767d8a51f1078fcccb4e64abd (patch)
treed48ed36a14abf0b15467c9ae9c7c04933fdd3a19 /sql
parent84ea287178247c163226e835490c9c70b17d8d3b (diff)
downloadspark-3390b400d04e40f767d8a51f1078fcccb4e64abd.tar.gz
spark-3390b400d04e40f767d8a51f1078fcccb4e64abd.tar.bz2
spark-3390b400d04e40f767d8a51f1078fcccb4e64abd.zip
[SPARK-10810] [SPARK-10902] [SQL] Improve session management in SQL
This PR improve the sessions management by replacing the thread-local based to one SQLContext per session approach, introduce separated temporary tables and UDFs/UDAFs for each session. A new session of SQLContext could be created by: 1) create an new SQLContext 2) call newSession() on existing SQLContext For HiveContext, in order to reduce the cost for each session, the classloader and Hive client are shared across multiple sessions (created by newSession). CacheManager is also shared by multiple sessions, so cache a table multiple times in different sessions will not cause multiple copies of in-memory cache. Added jars are still shared by all the sessions, because SparkContext does not support sessions. cc marmbrus yhuai rxin Author: Davies Liu <davies@databricks.com> Closes #8909 from davies/sessions.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala28
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala164
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala14
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala59
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala21
-rw-r--r--sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala76
-rw-r--r--sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala9
-rw-r--r--sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala5
-rw-r--r--sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala8
-rw-r--r--sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala76
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala155
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala28
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala9
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala85
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala107
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala27
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala27
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala13
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala6
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala32
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala9
21 files changed, 519 insertions, 439 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e6122d92b7..ba77b70a37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -51,23 +51,37 @@ class SimpleFunctionRegistry extends FunctionRegistry {
private val functionBuilders =
StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false)
- override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder)
- : Unit = {
+ override def registerFunction(
+ name: String,
+ info: ExpressionInfo,
+ builder: FunctionBuilder): Unit = synchronized {
functionBuilders.put(name, (info, builder))
}
override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
- val func = functionBuilders.get(name).map(_._2).getOrElse {
- throw new AnalysisException(s"undefined function $name")
+ val func = synchronized {
+ functionBuilders.get(name).map(_._2).getOrElse {
+ throw new AnalysisException(s"undefined function $name")
+ }
}
func(children)
}
- override def listFunction(): Seq[String] = functionBuilders.iterator.map(_._1).toList.sorted
+ override def listFunction(): Seq[String] = synchronized {
+ functionBuilders.iterator.map(_._1).toList.sorted
+ }
- override def lookupFunction(name: String): Option[ExpressionInfo] = {
+ override def lookupFunction(name: String): Option[ExpressionInfo] = synchronized {
functionBuilders.get(name).map(_._1)
}
+
+ def copy(): SimpleFunctionRegistry = synchronized {
+ val registry = new SimpleFunctionRegistry
+ functionBuilders.iterator.foreach { case (name, (info, builder)) =>
+ registry.registerFunction(name, info, builder)
+ }
+ registry
+ }
}
/**
@@ -257,7 +271,7 @@ object FunctionRegistry {
expression[InputFileName]("input_file_name")
)
- val builtin: FunctionRegistry = {
+ val builtin: SimpleFunctionRegistry = {
val fr = new SimpleFunctionRegistry
expressions.foreach { case (name, (info, builder)) => fr.registerFunction(name, info, builder) }
fr
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index cb0a3e361c..2bdfd82af0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -30,6 +30,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.errors.DialectException
@@ -38,15 +39,12 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
-import org.apache.spark.sql.execution.{Filter, _}
-import org.apache.spark.sql.{execution => sparkexecution}
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.sources._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
+import org.apache.spark.sql.{execution => sparkexecution}
import org.apache.spark.util.Utils
/**
@@ -64,18 +62,30 @@ import org.apache.spark.util.Utils
*
* @since 1.0.0
*/
-class SQLContext(@transient val sparkContext: SparkContext)
- extends org.apache.spark.Logging
- with Serializable {
+class SQLContext private[sql](
+ @transient val sparkContext: SparkContext,
+ @transient protected[sql] val cacheManager: CacheManager)
+ extends org.apache.spark.Logging with Serializable {
self =>
+ def this(sparkContext: SparkContext) = this(sparkContext, new CacheManager)
def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)
/**
+ * Returns a SQLContext as new session, with separated SQL configurations, temporary tables,
+ * registered functions, but sharing the same SparkContext and CacheManager.
+ *
+ * @since 1.6.0
+ */
+ def newSession(): SQLContext = {
+ new SQLContext(sparkContext, cacheManager)
+ }
+
+ /**
* @return Spark SQL configuration
*/
- protected[sql] def conf = currentSession().conf
+ protected[sql] lazy val conf = new SQLConf
// `listener` should be only used in the driver
@transient private[sql] val listener = new SQLListener(this)
@@ -142,13 +152,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
- // TODO how to handle the temp table per user session?
@transient
protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf)
- // TODO how to handle the temp function per user session?
@transient
- protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin
+ protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()
@transient
protected[sql] lazy val analyzer: Analyzer =
@@ -198,20 +206,19 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] def executePlan(plan: LogicalPlan) =
new sparkexecution.QueryExecution(this, plan)
- @transient
- protected[sql] val tlSession = new ThreadLocal[SQLSession]() {
- override def initialValue: SQLSession = defaultSession
- }
-
- @transient
- protected[sql] val defaultSession = createSession()
-
protected[sql] def dialectClassName = if (conf.dialect == "sql") {
classOf[DefaultParserDialect].getCanonicalName
} else {
conf.dialect
}
+ /**
+ * Add a jar to SQLContext
+ */
+ protected[sql] def addJar(path: String): Unit = {
+ sparkContext.addJar(path)
+ }
+
{
// We extract spark sql settings from SparkContext's conf and put them to
// Spark SQL's conf.
@@ -236,9 +243,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
}
- @transient
- protected[sql] val cacheManager = new CacheManager(this)
-
/**
* :: Experimental ::
* A collection of methods that are considered experimental, but can be used to hook into
@@ -300,21 +304,25 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group cachemgmt
* @since 1.3.0
*/
- def isCached(tableName: String): Boolean = cacheManager.isCached(tableName)
+ def isCached(tableName: String): Boolean = {
+ cacheManager.lookupCachedData(table(tableName)).nonEmpty
+ }
/**
* Caches the specified table in-memory.
* @group cachemgmt
* @since 1.3.0
*/
- def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName)
+ def cacheTable(tableName: String): Unit = {
+ cacheManager.cacheQuery(table(tableName), Some(tableName))
+ }
/**
* Removes the specified table from the in-memory cache.
* @group cachemgmt
* @since 1.3.0
*/
- def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName)
+ def uncacheTable(tableName: String): Unit = cacheManager.uncacheQuery(table(tableName))
/**
* Removes all cached tables from the in-memory cache.
@@ -830,36 +838,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
)
}
- protected[sql] def openSession(): SQLSession = {
- detachSession()
- val session = createSession()
- tlSession.set(session)
-
- session
- }
-
- protected[sql] def currentSession(): SQLSession = {
- tlSession.get()
- }
-
- protected[sql] def createSession(): SQLSession = {
- new this.SQLSession()
- }
-
- protected[sql] def detachSession(): Unit = {
- tlSession.remove()
- }
-
- protected[sql] def setSession(session: SQLSession): Unit = {
- detachSession()
- tlSession.set(session)
- }
-
- protected[sql] class SQLSession {
- // Note that this is a lazy val so we can override the default value in subclasses.
- protected[sql] lazy val conf: SQLConf = new SQLConf
- }
-
@deprecated("use org.apache.spark.sql.QueryExecution", "1.6.0")
protected[sql] class QueryExecution(logical: LogicalPlan)
extends sparkexecution.QueryExecution(this, logical)
@@ -1196,46 +1174,90 @@ class SQLContext(@transient val sparkContext: SparkContext)
// Register a succesfully instantiatd context to the singleton. This should be at the end of
// the class definition so that the singleton is updated only if there is no exception in the
// construction of the instance.
- SQLContext.setLastInstantiatedContext(self)
+ sparkContext.addSparkListener(new SparkListener {
+ override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
+ SQLContext.clearInstantiatedContext(self)
+ }
+ })
+
+ SQLContext.setInstantiatedContext(self)
}
/**
* This SQLContext object contains utility functions to create a singleton SQLContext instance,
- * or to get the last created SQLContext instance.
+ * or to get the created SQLContext instance.
+ *
+ * It also provides utility functions to support preference for threads in multiple sessions
+ * scenario, setActive could set a SQLContext for current thread, which will be returned by
+ * getOrCreate instead of the global one.
*/
object SQLContext {
- private val INSTANTIATION_LOCK = new Object()
+ /**
+ * The active SQLContext for the current thread.
+ */
+ private val activeContext: InheritableThreadLocal[SQLContext] =
+ new InheritableThreadLocal[SQLContext]
/**
- * Reference to the last created SQLContext.
+ * Reference to the created SQLContext.
*/
- @transient private val lastInstantiatedContext = new AtomicReference[SQLContext]()
+ @transient private val instantiatedContext = new AtomicReference[SQLContext]()
/**
* Get the singleton SQLContext if it exists or create a new one using the given SparkContext.
+ *
* This function can be used to create a singleton SQLContext object that can be shared across
* the JVM.
+ *
+ * If there is an active SQLContext for current thread, it will be returned instead of the global
+ * one.
+ *
+ * @since 1.5.0
*/
def getOrCreate(sparkContext: SparkContext): SQLContext = {
- INSTANTIATION_LOCK.synchronized {
- if (lastInstantiatedContext.get() == null) {
+ val ctx = activeContext.get()
+ if (ctx != null) {
+ return ctx
+ }
+
+ synchronized {
+ val ctx = instantiatedContext.get()
+ if (ctx == null) {
new SQLContext(sparkContext)
+ } else {
+ ctx
}
}
- lastInstantiatedContext.get()
}
- private[sql] def clearLastInstantiatedContext(): Unit = {
- INSTANTIATION_LOCK.synchronized {
- lastInstantiatedContext.set(null)
- }
+ private[sql] def clearInstantiatedContext(sqlContext: SQLContext): Unit = {
+ instantiatedContext.compareAndSet(sqlContext, null)
}
- private[sql] def setLastInstantiatedContext(sqlContext: SQLContext): Unit = {
- INSTANTIATION_LOCK.synchronized {
- lastInstantiatedContext.set(sqlContext)
- }
+ private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = {
+ instantiatedContext.compareAndSet(null, sqlContext)
+ }
+
+ /**
+ * Changes the SQLContext that will be returned in this thread and its children when
+ * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives
+ * a SQLContext with an isolated session, instead of the global (first created) context.
+ *
+ * @since 1.6.0
+ */
+ def setActive(sqlContext: SQLContext): Unit = {
+ activeContext.set(sqlContext)
+ }
+
+ /**
+ * Clears the active SQLContext for current thread. Subsequent calls to getOrCreate will
+ * return the first created context instead of a thread-local override.
+ *
+ * @since 1.6.0
+ */
+ def clearActive(): Unit = {
+ activeContext.remove()
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index d3e5c378d0..f85aeb1b02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql.execution
import java.util.concurrent.locks.ReentrantReadWriteLock
import org.apache.spark.Logging
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
@@ -37,7 +37,7 @@ private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMe
*
* Internal to Spark SQL.
*/
-private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
+private[sql] class CacheManager extends Logging {
@transient
private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
@@ -45,15 +45,6 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
@transient
private val cacheLock = new ReentrantReadWriteLock
- /** Returns true if the table is currently cached in-memory. */
- def isCached(tableName: String): Boolean = lookupCachedData(sqlContext.table(tableName)).nonEmpty
-
- /** Caches the specified table in-memory. */
- def cacheTable(tableName: String): Unit = cacheQuery(sqlContext.table(tableName), Some(tableName))
-
- /** Removes the specified table from the in-memory cache. */
- def uncacheTable(tableName: String): Unit = uncacheQuery(sqlContext.table(tableName))
-
/** Acquires a read lock on the cache for the duration of `f`. */
private def readLock[A](f: => A): A = {
val lock = cacheLock.readLock()
@@ -96,6 +87,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
if (lookupCachedData(planToCache).nonEmpty) {
logWarning("Asked to cache already cached data.")
} else {
+ val sqlContext = query.sqlContext
cachedData +=
CachedData(
planToCache,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index dd88ae3700..1994dacfc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -17,33 +17,52 @@
package org.apache.spark.sql
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.{SharedSparkContext, SparkFunSuite}
-class SQLContextSuite extends SparkFunSuite with SharedSQLContext {
-
- override def afterAll(): Unit = {
- try {
- SQLContext.setLastInstantiatedContext(sqlContext)
- } finally {
- super.afterAll()
- }
- }
+class SQLContextSuite extends SparkFunSuite with SharedSparkContext{
test("getOrCreate instantiates SQLContext") {
- SQLContext.clearLastInstantiatedContext()
- val sqlContext = SQLContext.getOrCreate(sparkContext)
+ val sqlContext = SQLContext.getOrCreate(sc)
assert(sqlContext != null, "SQLContext.getOrCreate returned null")
- assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext),
+ assert(SQLContext.getOrCreate(sc).eq(sqlContext),
"SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate")
}
- test("getOrCreate gets last explicitly instantiated SQLContext") {
- SQLContext.clearLastInstantiatedContext()
- val sqlContext = new SQLContext(sparkContext)
- assert(SQLContext.getOrCreate(sparkContext) != null,
- "SQLContext.getOrCreate after explicitly created SQLContext returned null")
- assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext),
+ test("getOrCreate return the original SQLContext") {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ val newSession = sqlContext.newSession()
+ assert(SQLContext.getOrCreate(sc).eq(sqlContext),
"SQLContext.getOrCreate after explicitly created SQLContext did not return the context")
+ SQLContext.setActive(newSession)
+ assert(SQLContext.getOrCreate(sc).eq(newSession),
+ "SQLContext.getOrCreate after explicitly setActive() did not return the active context")
+ }
+
+ test("Sessions of SQLContext") {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ val session1 = sqlContext.newSession()
+ val session2 = sqlContext.newSession()
+
+ // all have the default configurations
+ val key = SQLConf.SHUFFLE_PARTITIONS.key
+ assert(session1.getConf(key) === session2.getConf(key))
+ session1.setConf(key, "1")
+ session2.setConf(key, "2")
+ assert(session1.getConf(key) === "1")
+ assert(session2.getConf(key) === "2")
+
+ // temporary table should not be shared
+ val df = session1.range(10)
+ df.registerTempTable("test1")
+ assert(session1.tableNames().contains("test1"))
+ assert(!session2.tableNames().contains("test1"))
+
+ // UDF should not be shared
+ def myadd(a: Int, b: Int): Int = a + b
+ session1.udf.register[Int, Int, Int]("myadd", myadd)
+ session1.sql("select myadd(1, 2)").explain()
+ intercept[AnalysisException] {
+ session2.sql("select myadd(1, 2)").explain()
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 10e633f3cd..c89a151650 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -31,23 +31,16 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel
new SparkConf().set("spark.sql.testkey", "true")))
}
- // Make sure we set those test specific confs correctly when we create
- // the SQLConf as well as when we call clear.
- protected[sql] override def createSession(): SQLSession = new this.SQLSession()
+ protected[sql] override lazy val conf: SQLConf = new SQLConf {
- /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */
- protected[sql] class SQLSession extends super.SQLSession {
- protected[sql] override lazy val conf: SQLConf = new SQLConf {
+ clear()
- clear()
+ override def clear(): Unit = {
+ super.clear()
- override def clear(): Unit = {
- super.clear()
-
- // Make sure we start with the default test configs even after clear
- TestSQLContext.overrideConfs.map {
- case (key, value) => setConfString(key, value)
- }
+ // Make sure we start with the default test configs even after clear
+ TestSQLContext.overrideConfs.map {
+ case (key, value) => setConfString(key, value)
}
}
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index 306f98bcb5..719b03e1c7 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -20,19 +20,15 @@ package org.apache.spark.sql.hive.thriftserver
import java.security.PrivilegedExceptionAction
import java.sql.{Date, Timestamp}
import java.util.concurrent.RejectedExecutionException
-import java.util.{Arrays, Map => JMap, UUID}
+import java.util.{Arrays, UUID, Map => JMap}
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, Map => SMap}
import scala.util.control.NonFatal
-import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.api.FieldSchema
-import org.apache.hive.service.cli._
-import org.apache.hadoop.hive.ql.metadata.Hive
-import org.apache.hadoop.hive.ql.metadata.HiveException
-import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.shims.Utils
+import org.apache.hive.service.cli._
import org.apache.hive.service.cli.operation.ExecuteStatementOperation
import org.apache.hive.service.cli.session.HiveSession
@@ -40,7 +36,7 @@ import org.apache.spark.Logging
import org.apache.spark.sql.execution.SetCommand
import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf}
+import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow}
private[hive] class SparkExecuteStatementOperation(
@@ -143,30 +139,15 @@ private[hive] class SparkExecuteStatementOperation(
if (!runInBackground) {
runInternal()
} else {
- val parentSessionState = SessionState.get()
- val hiveConf = getConfigForOperation()
val sparkServiceUGI = Utils.getUGI()
- val sessionHive = getCurrentHive()
- val currentSqlSession = hiveContext.currentSession
// Runnable impl to call runInternal asynchronously,
// from a different thread
val backgroundOperation = new Runnable() {
override def run(): Unit = {
- val doAsAction = new PrivilegedExceptionAction[Object]() {
- override def run(): Object = {
-
- // User information is part of the metastore client member in Hive
- hiveContext.setSession(currentSqlSession)
- // Always use the latest class loader provided by executionHive's state.
- val executionHiveClassLoader =
- hiveContext.executionHive.state.getConf.getClassLoader
- sessionHive.getConf.setClassLoader(executionHiveClassLoader)
- parentSessionState.getConf.setClassLoader(executionHiveClassLoader)
-
- Hive.set(sessionHive)
- SessionState.setCurrentSessionState(parentSessionState)
+ val doAsAction = new PrivilegedExceptionAction[Unit]() {
+ override def run(): Unit = {
try {
runInternal()
} catch {
@@ -174,7 +155,6 @@ private[hive] class SparkExecuteStatementOperation(
setOperationException(e)
log.error("Error running hive query: ", e)
}
- return null
}
}
@@ -191,7 +171,7 @@ private[hive] class SparkExecuteStatementOperation(
try {
// This submit blocks if no background threads are available to run this operation
val backgroundHandle =
- getParentSession().getSessionManager().submitBackgroundOperation(backgroundOperation)
+ parentSession.getSessionManager().submitBackgroundOperation(backgroundOperation)
setBackgroundHandle(backgroundHandle)
} catch {
case rejected: RejectedExecutionException =>
@@ -210,6 +190,11 @@ private[hive] class SparkExecuteStatementOperation(
statementId = UUID.randomUUID().toString
logInfo(s"Running query '$statement' with $statementId")
setState(OperationState.RUNNING)
+ // Always use the latest class loader provided by executionHive's state.
+ val executionHiveClassLoader =
+ hiveContext.executionHive.state.getConf.getClassLoader
+ Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
+
HiveThriftServer2.listener.onStatementStart(
statementId,
parentSession.getSessionHandle.getSessionId.toString,
@@ -279,43 +264,4 @@ private[hive] class SparkExecuteStatementOperation(
}
}
}
-
- /**
- * If there are query specific settings to overlay, then create a copy of config
- * There are two cases we need to clone the session config that's being passed to hive driver
- * 1. Async query -
- * If the client changes a config setting, that shouldn't reflect in the execution
- * already underway
- * 2. confOverlay -
- * The query specific settings should only be applied to the query config and not session
- * @return new configuration
- * @throws HiveSQLException
- */
- private def getConfigForOperation(): HiveConf = {
- var sqlOperationConf = getParentSession().getHiveConf()
- if (!getConfOverlay().isEmpty() || runInBackground) {
- // clone the partent session config for this query
- sqlOperationConf = new HiveConf(sqlOperationConf)
-
- // apply overlay query specific settings, if any
- getConfOverlay().asScala.foreach { case (k, v) =>
- try {
- sqlOperationConf.verifyAndSet(k, v)
- } catch {
- case e: IllegalArgumentException =>
- throw new HiveSQLException("Error applying statement specific settings", e)
- }
- }
- }
- return sqlOperationConf
- }
-
- private def getCurrentHive(): Hive = {
- try {
- return Hive.get()
- } catch {
- case e: HiveException =>
- throw new HiveSQLException("Failed to get current Hive object", e);
- }
- }
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
index 92ac0ec3fc..33aaead3fb 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
@@ -36,7 +36,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext:
extends SessionManager(hiveServer)
with ReflectedCompositeService {
- private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext)
+ private lazy val sparkSqlOperationManager = new SparkSQLOperationManager()
override def init(hiveConf: HiveConf) {
setSuperField(this, "hiveConf", hiveConf)
@@ -60,13 +60,15 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext:
sessionConf: java.util.Map[String, String],
withImpersonation: Boolean,
delegationToken: String): SessionHandle = {
- hiveContext.openSession()
val sessionHandle =
super.openSession(protocol, username, passwd, ipAddress, sessionConf, withImpersonation,
delegationToken)
val session = super.getSession(sessionHandle)
HiveThriftServer2.listener.onSessionCreated(
session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername)
+ val ctx = hiveContext.newSession()
+ ctx.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion)
+ sparkSqlOperationManager.sessionToContexts += sessionHandle -> ctx
sessionHandle
}
@@ -74,7 +76,6 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext:
HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString)
super.closeSession(sessionHandle)
sparkSqlOperationManager.sessionToActivePool -= sessionHandle
-
- hiveContext.detachSession()
+ sparkSqlOperationManager.sessionToContexts.remove(sessionHandle)
}
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index c8031ed0f3..476651a559 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -30,20 +30,21 @@ import org.apache.spark.sql.hive.thriftserver.{SparkExecuteStatementOperation, R
/**
* Executes queries using Spark SQL, and maintains a list of handles to active queries.
*/
-private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
+private[thriftserver] class SparkSQLOperationManager()
extends OperationManager with Logging {
val handleToOperation = ReflectionUtils
.getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation")
val sessionToActivePool = Map[SessionHandle, String]()
+ val sessionToContexts = Map[SessionHandle, HiveContext]()
override def newExecuteStatementOperation(
parentSession: HiveSession,
statement: String,
confOverlay: JMap[String, String],
async: Boolean): ExecuteStatementOperation = synchronized {
-
+ val hiveContext = sessionToContexts(parentSession.getSessionHandle)
val runInBackground = async && hiveContext.hiveThriftServerAsync
val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay,
runInBackground)(hiveContext, sessionToActivePool)
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index e59a14ec00..76d1591a23 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -96,7 +96,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging {
buffer += s"${new Timestamp(new Date().getTime)} - $source> $line"
// If we haven't found all expected answers and another expected answer comes up...
- if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) {
+ if (next < expectedAnswers.size && line.contains(expectedAnswers(next))) {
next += 1
// If all expected answers have been found...
if (next == expectedAnswers.size) {
@@ -159,7 +159,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging {
s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE hive_test;"
-> "OK",
"CACHE TABLE hive_test;"
- -> "Time taken: ",
+ -> "",
"SELECT COUNT(*) FROM hive_test;"
-> "5",
"DROP TABLE hive_test;"
@@ -180,7 +180,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging {
"CREATE TABLE hive_test(key INT, val STRING);"
-> "OK",
"SHOW TABLES;"
- -> "Time taken: "
+ -> "hive_test"
)
runCliWithin(2.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))(
@@ -210,7 +210,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging {
s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;"
-> "OK",
"INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;"
- -> "Time taken:",
+ -> "",
"SELECT count(key) FROM t1;"
-> "5",
"DROP TABLE t1;"
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 19b2f24456..ff8ca01506 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -205,6 +205,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
import org.apache.spark.sql.SQLConf
var defaultV1: String = null
var defaultV2: String = null
+ var data: ArrayBuffer[Int] = null
withMultipleConnectionJdbcStatement(
// create table
@@ -214,10 +215,16 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
"DROP TABLE IF EXISTS test_map",
"CREATE TABLE test_map(key INT, value STRING)",
s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map",
- "CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC")
+ "CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC",
+ "CREATE DATABASE db1")
queries.foreach(statement.execute)
+ val plan = statement.executeQuery("explain select * from test_table")
+ plan.next()
+ plan.next()
+ assert(plan.getString(1).contains("InMemoryColumnarTableScan"))
+
val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC")
val buf1 = new collection.mutable.ArrayBuffer[Int]()
while (rs1.next()) {
@@ -233,6 +240,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
rs2.close()
assert(buf1 === buf2)
+
+ data = buf1
},
// first session, we get the default value of the session status
@@ -289,56 +298,51 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
rs2.close()
},
- // accessing the cached data in another session
+ // try to access the cached data in another session
{ statement =>
- val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC")
- val buf1 = new collection.mutable.ArrayBuffer[Int]()
- while (rs1.next()) {
- buf1 += rs1.getInt(1)
+ // Cached temporary table can't be accessed by other sessions
+ intercept[SQLException] {
+ statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC")
}
- rs1.close()
- val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
- val buf2 = new collection.mutable.ArrayBuffer[Int]()
- while (rs2.next()) {
- buf2 += rs2.getInt(1)
+ val plan = statement.executeQuery("explain select key from test_map ORDER BY key DESC")
+ plan.next()
+ plan.next()
+ assert(plan.getString(1).contains("InMemoryColumnarTableScan"))
+
+ val rs = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
+ val buf = new collection.mutable.ArrayBuffer[Int]()
+ while (rs.next()) {
+ buf += rs.getInt(1)
}
- rs2.close()
+ rs.close()
+ assert(buf === data)
+ },
- assert(buf1 === buf2)
- statement.executeQuery("UNCACHE TABLE test_table")
+ // switch another database
+ { statement =>
+ statement.execute("USE db1")
- // TODO need to figure out how to determine if the data loaded from cache
- val rs3 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
- val buf3 = new collection.mutable.ArrayBuffer[Int]()
- while (rs3.next()) {
- buf3 += rs3.getInt(1)
+ // there is no test_map table in db1
+ intercept[SQLException] {
+ statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
}
- rs3.close()
- assert(buf1 === buf3)
+ statement.execute("CREATE TABLE test_map2(key INT, value STRING)")
},
- // accessing the uncached table
+ // access default database
{ statement =>
- // TODO need to figure out how to determine if the data loaded from cache
- val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC")
- val buf1 = new collection.mutable.ArrayBuffer[Int]()
- while (rs1.next()) {
- buf1 += rs1.getInt(1)
- }
- rs1.close()
-
- val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
- val buf2 = new collection.mutable.ArrayBuffer[Int]()
- while (rs2.next()) {
- buf2 += rs2.getInt(1)
+ // current database should still be `default`
+ intercept[SQLException] {
+ statement.executeQuery("SELECT key FROM test_map2")
}
- rs2.close()
- assert(buf1 === buf2)
+ statement.execute("USE db1")
+ // access test_map2
+ statement.executeQuery("SELECT key from test_map2")
}
)
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 17de8ef56f..dad1e2347c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -25,7 +25,6 @@ import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import scala.collection.mutable.HashMap
import scala.language.implicitConversions
-import scala.concurrent.duration._
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.common.StatsSetupConst
@@ -34,32 +33,49 @@ import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hadoop.hive.ql.metadata.Table
import org.apache.hadoop.hive.ql.parse.VariableSubstitution
-import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable}
-import org.apache.spark.Logging
-import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql._
+import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.SQLConf.SQLConfEntry._
-import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier, ParserDialect}
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand}
-import org.apache.spark.sql.execution.datasources.{PreWriteCheck, PreInsertCastAndRename, DataSourceStrategy}
+import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser}
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PreInsertCastAndRename, PreWriteCheck}
+import org.apache.spark.sql.execution.{CacheManager, ExecutedCommand, ExtractPythonUDFs, SetCommand}
import org.apache.spark.sql.hive.client._
import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkContext}
/**
* This is the HiveQL Dialect, this dialect is strongly bind with HiveContext
*/
-private[hive] class HiveQLDialect extends ParserDialect {
+private[hive] class HiveQLDialect(sqlContext: HiveContext) extends ParserDialect {
override def parse(sqlText: String): LogicalPlan = {
- HiveQl.parseSql(sqlText)
+ sqlContext.executionHive.withHiveState {
+ HiveQl.parseSql(sqlText)
+ }
+ }
+}
+
+/**
+ * Returns the current database of metadataHive.
+ */
+private[hive] case class CurrentDatabase(ctx: HiveContext)
+ extends LeafExpression with CodegenFallback {
+ override def dataType: DataType = StringType
+ override def foldable: Boolean = true
+ override def nullable: Boolean = false
+ override def eval(input: InternalRow): Any = {
+ UTF8String.fromString(ctx.metadataHive.currentDatabase)
}
}
@@ -69,14 +85,30 @@ private[hive] class HiveQLDialect extends ParserDialect {
*
* @since 1.0.0
*/
-class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
+class HiveContext private[hive](
+ sc: SparkContext,
+ cacheManager: CacheManager,
+ @transient execHive: ClientWrapper,
+ @transient metaHive: ClientInterface) extends SQLContext(sc, cacheManager) with Logging {
self =>
- import HiveContext._
+ def this(sc: SparkContext) = this(sc, new CacheManager, null, null)
+ def this(sc: JavaSparkContext) = this(sc.sc)
+
+ import org.apache.spark.sql.hive.HiveContext._
logDebug("create HiveContext")
/**
+ * Returns a new HiveContext as new session, which will have separated SQLConf, UDF/UDAF,
+ * temporary tables and SessionState, but sharing the same CacheManager, IsolatedClientLoader
+ * and Hive client (both of execution and metadata) with existing HiveContext.
+ */
+ override def newSession(): HiveContext = {
+ new HiveContext(sc, cacheManager, executionHive.newSession(), metadataHive.newSession())
+ }
+
+ /**
* When true, enables an experimental feature where metastore tables that use the parquet SerDe
* are automatically converted to use the Spark SQL parquet table scan, instead of the Hive
* SerDe.
@@ -157,14 +189,18 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
* for storing persistent metadata, and only point to a dummy metastore in a temporary directory.
*/
@transient
- protected[hive] lazy val executionHive: ClientWrapper = {
+ protected[hive] lazy val executionHive: ClientWrapper = if (execHive != null) {
+ execHive
+ } else {
logInfo(s"Initializing execution hive, version $hiveExecutionVersion")
- new ClientWrapper(
+ val loader = new IsolatedClientLoader(
version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion),
+ execJars = Seq(),
config = newTemporaryConfiguration(),
- initClassLoader = Utils.getContextOrSparkClassLoader)
+ isolationOn = false,
+ baseClassLoader = Utils.getContextOrSparkClassLoader)
+ loader.createClient().asInstanceOf[ClientWrapper]
}
- SessionState.setCurrentSessionState(executionHive.state)
/**
* Overrides default Hive configurations to avoid breaking changes to Spark SQL users.
@@ -182,7 +218,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
* in the hive-site.xml file.
*/
@transient
- protected[hive] lazy val metadataHive: ClientInterface = {
+ protected[hive] lazy val metadataHive: ClientInterface = if (metaHive != null) {
+ metaHive
+ } else {
val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion)
// We instantiate a HiveConf here to read in the hive-site.xml file and then pass the options
@@ -268,14 +306,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
barrierPrefixes = hiveMetastoreBarrierPrefixes,
sharedPrefixes = hiveMetastoreSharedPrefixes)
}
- isolatedLoader.client
+ isolatedLoader.createClient()
}
protected[sql] override def parseSql(sql: String): LogicalPlan = {
- var state = SessionState.get()
- if (state == null) {
- SessionState.setCurrentSessionState(tlSession.get().asInstanceOf[SQLSession].sessionState)
- }
super.parseSql(substitutor.substitute(hiveconf, sql))
}
@@ -384,8 +418,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
}
}
- protected[hive] def hiveconf = tlSession.get().asInstanceOf[this.SQLSession].hiveconf
-
override def setConf(key: String, value: String): Unit = {
super.setConf(key, value)
executionHive.runSqlHive(s"SET $key=$value")
@@ -402,7 +434,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
setConf(entry.key, entry.stringConverter(value))
}
- /* A catalyst metadata catalog that points to the Hive Metastore. */
+ /* A catalyst metadata catalog that points to the Hive Metastore. */
@transient
override protected[sql] lazy val catalog =
new HiveMetastoreCatalog(metadataHive, this) with OverrideCatalog
@@ -410,7 +442,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
// Note that HiveUDFs will be overridden by functions registered in this context.
@transient
override protected[sql] lazy val functionRegistry: FunctionRegistry =
- new HiveFunctionRegistry(FunctionRegistry.builtin)
+ new HiveFunctionRegistry(FunctionRegistry.builtin.copy())
+
+ // The Hive UDF current_database() is foldable, will be evaluated by optimizer, but the optimizer
+ // can't access the SessionState of metadataHive.
+ functionRegistry.registerFunction(
+ "current_database",
+ (expressions: Seq[Expression]) => new CurrentDatabase(this))
/* An analyzer that uses the Hive metastore. */
@transient
@@ -430,10 +468,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
)
}
- override protected[sql] def createSession(): SQLSession = {
- new this.SQLSession()
- }
-
/** Overridden by child classes that need to set configuration before the client init. */
protected def configure(): Map[String, String] = {
// Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch
@@ -488,41 +522,40 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
}.toMap
}
- protected[hive] class SQLSession extends super.SQLSession {
- protected[sql] override lazy val conf: SQLConf = new SQLConf {
- override def dialect: String = getConf(SQLConf.DIALECT, "hiveql")
- override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
- }
-
- /**
- * SQLConf and HiveConf contracts:
- *
- * 1. reuse existing started SessionState if any
- * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the
- * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be
- * set in the SQLConf *as well as* in the HiveConf.
- */
- protected[hive] lazy val sessionState: SessionState = {
- var state = SessionState.get()
- if (state == null) {
- state = new SessionState(new HiveConf(classOf[SessionState]))
- SessionState.start(state)
- }
- state
- }
+ /**
+ * SQLConf and HiveConf contracts:
+ *
+ * 1. create a new SessionState for each HiveContext
+ * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the
+ * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be
+ * set in the SQLConf *as well as* in the HiveConf.
+ */
+ @transient
+ protected[hive] lazy val hiveconf: HiveConf = {
+ val c = executionHive.conf
+ setConf(c.getAllProperties)
+ c
+ }
- protected[hive] lazy val hiveconf: HiveConf = {
- setConf(sessionState.getConf.getAllProperties)
- sessionState.getConf
- }
+ protected[sql] override lazy val conf: SQLConf = new SQLConf {
+ override def dialect: String = getConf(SQLConf.DIALECT, "hiveql")
+ override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
}
- override protected[sql] def dialectClassName = if (conf.dialect == "hiveql") {
+ protected[sql] override def dialectClassName = if (conf.dialect == "hiveql") {
classOf[HiveQLDialect].getCanonicalName
} else {
super.dialectClassName
}
+ protected[sql] override def getSQLDialect(): ParserDialect = {
+ if (conf.dialect == "hiveql") {
+ new HiveQLDialect(this)
+ } else {
+ super.getSQLDialect()
+ }
+ }
+
@transient
private val hivePlanner = new SparkPlanner with HiveStrategies {
val hiveContext = self
@@ -598,6 +631,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
case _ => super.simpleString
}
}
+
+ protected[sql] override def addJar(path: String): Unit = {
+ // Add jar to Hive and classloader
+ executionHive.addJar(path)
+ metadataHive.addJar(path)
+ Thread.currentThread().setContextClassLoader(executionHive.clientLoader.classLoader)
+ super.addJar(path)
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 2bf22f5449..250c232856 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -25,29 +25,27 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hadoop.hive.serde.serdeConstants
-import org.apache.hadoop.hive.ql.{ErrorMsg, Context}
-import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
+import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
import org.apache.hadoop.hive.ql.lib.Node
import org.apache.hadoop.hive.ql.parse._
import org.apache.hadoop.hive.ql.plan.PlanUtils
import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.ql.{Context, ErrorMsg}
+import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.spark.Logging
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.{AnalysisException, catalyst}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.{logical, _}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.execution.ExplainCommand
import org.apache.spark.sql.execution.datasources.DescribeCommand
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.hive.client._
-import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema}
+import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.random.RandomSampler
@@ -268,7 +266,7 @@ private[hive] object HiveQl extends Logging {
node
}
- private def createContext(): Context = new Context(SessionState.get().getConf())
+ private def createContext(): Context = new Context(hiveConf)
private def getAst(sql: String, context: Context) =
ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, context))
@@ -277,12 +275,16 @@ private[hive] object HiveQl extends Logging {
* Returns the HiveConf
*/
private[this] def hiveConf: HiveConf = {
- val ss = SessionState.get() // SessionState is lazy initialization, it can be null here
+ var ss = SessionState.get()
+ // SessionState is lazy initialization, it can be null here
if (ss == null) {
- new HiveConf()
- } else {
- ss.getConf
+ val original = Thread.currentThread().getContextClassLoader
+ val conf = new HiveConf(classOf[SessionState])
+ conf.setClassLoader(original)
+ ss = new SessionState(conf)
+ SessionState.start(ss)
}
+ ss.getConf
}
/** Returns a LogicalPlan for a given HiveQL string. */
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
index 915eae9d21..9d9a55edd7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
@@ -178,6 +178,15 @@ private[hive] trait ClientInterface {
holdDDLTime: Boolean,
listBucketingEnabled: Boolean): Unit
+ /** Add a jar into class loader */
+ def addJar(path: String): Unit
+
+ /** Return a ClientInterface as new session, that will share the class loader and Hive client */
+ def newSession(): ClientInterface
+
+ /** Run a function within Hive state (SessionState, HiveConf, Hive client and class loader) */
+ def withHiveState[A](f: => A): A
+
/** Used for testing only. Removes all metadata from this instance of Hive. */
def reset(): Unit
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index 8f6d448b2a..3dce86c480 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -60,7 +60,8 @@ import org.apache.spark.util.{CircularBuffer, Utils}
private[hive] class ClientWrapper(
override val version: HiveVersion,
config: Map[String, String],
- initClassLoader: ClassLoader)
+ initClassLoader: ClassLoader,
+ val clientLoader: IsolatedClientLoader)
extends ClientInterface
with Logging {
@@ -150,31 +151,29 @@ private[hive] class ClientWrapper(
// Switch to the initClassLoader.
Thread.currentThread().setContextClassLoader(initClassLoader)
val ret = try {
- val oldState = SessionState.get()
- if (oldState == null) {
- val initialConf = new HiveConf(classOf[SessionState])
- // HiveConf is a Hadoop Configuration, which has a field of classLoader and
- // the initial value will be the current thread's context class loader
- // (i.e. initClassLoader at here).
- // We call initialConf.setClassLoader(initClassLoader) at here to make
- // this action explicit.
- initialConf.setClassLoader(initClassLoader)
- config.foreach { case (k, v) =>
- if (k.toLowerCase.contains("password")) {
- logDebug(s"Hive Config: $k=xxx")
- } else {
- logDebug(s"Hive Config: $k=$v")
- }
- initialConf.set(k, v)
+ val initialConf = new HiveConf(classOf[SessionState])
+ // HiveConf is a Hadoop Configuration, which has a field of classLoader and
+ // the initial value will be the current thread's context class loader
+ // (i.e. initClassLoader at here).
+ // We call initialConf.setClassLoader(initClassLoader) at here to make
+ // this action explicit.
+ initialConf.setClassLoader(initClassLoader)
+ config.foreach { case (k, v) =>
+ if (k.toLowerCase.contains("password")) {
+ logDebug(s"Hive Config: $k=xxx")
+ } else {
+ logDebug(s"Hive Config: $k=$v")
}
- val newState = new SessionState(initialConf)
- SessionState.start(newState)
- newState.out = new PrintStream(outputBuffer, true, "UTF-8")
- newState.err = new PrintStream(outputBuffer, true, "UTF-8")
- newState
- } else {
- oldState
+ initialConf.set(k, v)
+ }
+ val state = new SessionState(initialConf)
+ if (clientLoader.cachedHive != null) {
+ Hive.set(clientLoader.cachedHive.asInstanceOf[Hive])
}
+ SessionState.start(state)
+ state.out = new PrintStream(outputBuffer, true, "UTF-8")
+ state.err = new PrintStream(outputBuffer, true, "UTF-8")
+ state
} finally {
Thread.currentThread().setContextClassLoader(original)
}
@@ -188,11 +187,6 @@ private[hive] class ClientWrapper(
conf.get(key, defaultValue)
}
- // TODO: should be a def?s
- // When we create this val client, the HiveConf of it (conf) is the one associated with state.
- @GuardedBy("this")
- private var client = Hive.get(conf)
-
// We use hive's conf for compatibility.
private val retryLimit = conf.getIntVar(HiveConf.ConfVars.METASTORETHRIFTFAILURERETRIES)
private val retryDelayMillis = shim.getMetastoreClientConnectRetryDelayMillis(conf)
@@ -200,7 +194,7 @@ private[hive] class ClientWrapper(
/**
* Runs `f` with multiple retries in case the hive metastore is temporarily unreachable.
*/
- private def retryLocked[A](f: => A): A = synchronized {
+ private def retryLocked[A](f: => A): A = clientLoader.synchronized {
// Hive sometimes retries internally, so set a deadline to avoid compounding delays.
val deadline = System.nanoTime + (retryLimit * retryDelayMillis * 1e6).toLong
var numTries = 0
@@ -215,13 +209,8 @@ private[hive] class ClientWrapper(
logWarning(
"HiveClientWrapper got thrift exception, destroying client and retrying " +
s"(${retryLimit - numTries} tries remaining)", e)
+ clientLoader.cachedHive = null
Thread.sleep(retryDelayMillis)
- try {
- client = Hive.get(state.getConf, true)
- } catch {
- case e: Exception if causedByThrift(e) =>
- logWarning("Failed to refresh hive client, will retry.", e)
- }
}
} while (numTries <= retryLimit && System.nanoTime < deadline)
if (System.nanoTime > deadline) {
@@ -242,13 +231,26 @@ private[hive] class ClientWrapper(
false
}
+ def client: Hive = {
+ if (clientLoader.cachedHive != null) {
+ clientLoader.cachedHive.asInstanceOf[Hive]
+ } else {
+ val c = Hive.get(conf)
+ clientLoader.cachedHive = c
+ c
+ }
+ }
+
/**
* Runs `f` with ThreadLocal session state and classloaders configured for this version of hive.
*/
- private def withHiveState[A](f: => A): A = retryLocked {
+ def withHiveState[A](f: => A): A = retryLocked {
val original = Thread.currentThread().getContextClassLoader
// Set the thread local metastore client to the client associated with this ClientWrapper.
Hive.set(client)
+ // The classloader in clientLoader could be changed after addJar, always use the latest
+ // classloader
+ state.getConf.setClassLoader(clientLoader.classLoader)
// setCurrentSessionState will use the classLoader associated
// with the HiveConf in `state` to override the context class loader of the current
// thread.
@@ -545,6 +547,15 @@ private[hive] class ClientWrapper(
listBucketingEnabled)
}
+ def addJar(path: String): Unit = {
+ clientLoader.addJar(path)
+ runSqlHive(s"ADD JAR $path")
+ }
+
+ def newSession(): ClientWrapper = {
+ clientLoader.createClient().asInstanceOf[ClientWrapper]
+ }
+
def reset(): Unit = withHiveState {
client.getAllTables("default").asScala.foreach { t =>
logDebug(s"Deleting table $t")
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
index 1fe4cba957..567e4d7b41 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
@@ -22,6 +22,7 @@ import java.lang.reflect.InvocationTargetException
import java.net.{URL, URLClassLoader}
import java.util
+import scala.collection.mutable
import scala.language.reflectiveCalls
import scala.util.Try
@@ -148,53 +149,75 @@ private[hive] class IsolatedClientLoader(
name.replaceAll("\\.", "/") + ".class"
/** The classloader that is used to load an isolated version of Hive. */
- protected val classLoader: ClassLoader = new URLClassLoader(allJars, rootClassLoader) {
- override def loadClass(name: String, resolve: Boolean): Class[_] = {
- val loaded = findLoadedClass(name)
- if (loaded == null) doLoadClass(name, resolve) else loaded
- }
-
- def doLoadClass(name: String, resolve: Boolean): Class[_] = {
- val classFileName = name.replaceAll("\\.", "/") + ".class"
- if (isBarrierClass(name) && isolationOn) {
- // For barrier classes, we construct a new copy of the class.
- val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName))
- logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}")
- defineClass(name, bytes, 0, bytes.length)
- } else if (!isSharedClass(name)) {
- logDebug(s"hive class: $name - ${getResource(classToPath(name))}")
- super.loadClass(name, resolve)
- } else {
- // For shared classes, we delegate to baseClassLoader.
- logDebug(s"shared class: $name")
- baseClassLoader.loadClass(name)
+ private[hive] var classLoader: ClassLoader = if (isolationOn) {
+ new URLClassLoader(allJars, rootClassLoader) {
+ override def loadClass(name: String, resolve: Boolean): Class[_] = {
+ val loaded = findLoadedClass(name)
+ if (loaded == null) doLoadClass(name, resolve) else loaded
+ }
+ def doLoadClass(name: String, resolve: Boolean): Class[_] = {
+ val classFileName = name.replaceAll("\\.", "/") + ".class"
+ if (isBarrierClass(name)) {
+ // For barrier classes, we construct a new copy of the class.
+ val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName))
+ logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}")
+ defineClass(name, bytes, 0, bytes.length)
+ } else if (!isSharedClass(name)) {
+ logDebug(s"hive class: $name - ${getResource(classToPath(name))}")
+ super.loadClass(name, resolve)
+ } else {
+ // For shared classes, we delegate to baseClassLoader.
+ logDebug(s"shared class: $name")
+ baseClassLoader.loadClass(name)
+ }
}
}
+ } else {
+ baseClassLoader
}
- // Pre-reflective instantiation setup.
- logDebug("Initializing the logger to avoid disaster...")
- Thread.currentThread.setContextClassLoader(classLoader)
+ private[hive] def addJar(path: String): Unit = synchronized {
+ val jarURL = new java.io.File(path).toURI.toURL
+ // TODO: we should avoid of stacking classloaders (use a single URLClassLoader and add jars
+ // to that)
+ classLoader = new java.net.URLClassLoader(Array(jarURL), classLoader)
+ }
/** The isolated client interface to Hive. */
- val client: ClientInterface = try {
- classLoader
- .loadClass(classOf[ClientWrapper].getName)
- .getConstructors.head
- .newInstance(version, config, classLoader)
- .asInstanceOf[ClientInterface]
- } catch {
- case e: InvocationTargetException =>
- if (e.getCause().isInstanceOf[NoClassDefFoundError]) {
- val cnf = e.getCause().asInstanceOf[NoClassDefFoundError]
- throw new ClassNotFoundException(
- s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" +
- "Please make sure that jars for your version of hive and hadoop are included in the " +
- s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.")
- } else {
- throw e
- }
- } finally {
- Thread.currentThread.setContextClassLoader(baseClassLoader)
+ private[hive] def createClient(): ClientInterface = {
+ if (!isolationOn) {
+ return new ClientWrapper(version, config, baseClassLoader, this)
+ }
+ // Pre-reflective instantiation setup.
+ logDebug("Initializing the logger to avoid disaster...")
+ val origLoader = Thread.currentThread().getContextClassLoader
+ Thread.currentThread.setContextClassLoader(classLoader)
+
+ try {
+ classLoader
+ .loadClass(classOf[ClientWrapper].getName)
+ .getConstructors.head
+ .newInstance(version, config, classLoader, this)
+ .asInstanceOf[ClientInterface]
+ } catch {
+ case e: InvocationTargetException =>
+ if (e.getCause().isInstanceOf[NoClassDefFoundError]) {
+ val cnf = e.getCause().asInstanceOf[NoClassDefFoundError]
+ throw new ClassNotFoundException(
+ s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" +
+ "Please make sure that jars for your version of hive and hadoop are included in the " +
+ s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.")
+ } else {
+ throw e
+ }
+ } finally {
+ Thread.currentThread.setContextClassLoader(origLoader)
+ }
}
+
+ /**
+ * The place holder for shared Hive client for all the HiveContext sessions (they share an
+ * IsolatedClientLoader).
+ */
+ private[hive] var cachedHive: Any = null
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 9f654eed57..51ec92afd0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -18,18 +18,18 @@
package org.apache.spark.sql.hive.execution
import org.apache.hadoop.hive.metastore.MetaStoreUtils
+
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{TableIdentifier, SqlParser}
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.RunnableCommand
-import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
/**
* Analyzes the given table in the current database to generate statistics, which will be
@@ -86,26 +86,7 @@ case class AddJar(path: String) extends RunnableCommand {
}
override def run(sqlContext: SQLContext): Seq[Row] = {
- val hiveContext = sqlContext.asInstanceOf[HiveContext]
- val currentClassLoader = Utils.getContextOrSparkClassLoader
-
- // Add jar to current context
- val jarURL = new java.io.File(path).toURI.toURL
- val newClassLoader = new java.net.URLClassLoader(Array(jarURL), currentClassLoader)
- Thread.currentThread.setContextClassLoader(newClassLoader)
- // We need to explicitly set the class loader associated with the conf in executionHive's
- // state because this class loader will be used as the context class loader of the current
- // thread to execute any Hive command.
- // We cannot use `org.apache.hadoop.hive.ql.metadata.Hive.get().getConf()` because Hive.get()
- // returns the value of a thread local variable and its HiveConf may not be the HiveConf
- // associated with `executionHive.state` (for example, HiveContext is created in one thread
- // and then add jar is called from another thread).
- hiveContext.executionHive.state.getConf.setClassLoader(newClassLoader)
- // Add jar to isolated hive (metadataHive) class loader.
- hiveContext.runSqlHive(s"ADD JAR $path")
-
- // Add jar to executors
- hiveContext.sparkContext.addJar(path)
+ sqlContext.addJar(path)
Seq(Row(0))
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index be335a47dc..ff39ccb7c1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -116,27 +116,18 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
override def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution(plan)
- // Make sure we set those test specific confs correctly when we create
- // the SQLConf as well as when we call clear.
- override protected[sql] def createSession(): SQLSession = {
- new this.SQLSession()
- }
-
- protected[hive] class SQLSession extends super.SQLSession {
- protected[sql] override lazy val conf: SQLConf = new SQLConf {
- // TODO as in unit test, conf.clear() probably be called, all of the value will be cleared.
- // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql"
- override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql")
- override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
+ protected[sql] override lazy val conf: SQLConf = new SQLConf {
+ // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql"
+ override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql")
+ override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
- clear()
+ clear()
- override def clear(): Unit = {
- super.clear()
+ override def clear(): Unit = {
+ super.clear()
- TestHiveContext.overrideConfs.map {
- case (key, value) => setConfString(key, value)
- }
+ TestHiveContext.overrideConfs.map {
+ case (key, value) => setConfString(key, value)
}
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
index 79cf40aba4..528a7398b1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
@@ -17,22 +17,15 @@
package org.apache.spark.sql.hive
-import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde.serdeConstants
+import org.scalatest.BeforeAndAfterAll
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.hive.client.{ManagedTable, HiveColumn, ExternalTable, HiveTable}
-import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, ManagedTable}
class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
- override def beforeAll() {
- if (SessionState.get() == null) {
- SessionState.start(new HiveConf())
- }
- }
-
private def extractTableDesc(sql: String): (HiveTable, Boolean) = {
HiveQl.createPlan(sql).collect {
case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index 2da22ec237..c6d034a23a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -53,7 +53,7 @@ class VersionsSuite extends SparkFunSuite with Logging {
test("success sanity check") {
val badClient = IsolatedClientLoader.forVersion(HiveContext.hiveExecutionVersion,
buildConf(),
- ivyPath).client
+ ivyPath).createClient()
val db = new HiveDatabase("default", "")
badClient.createDatabase(db)
}
@@ -83,7 +83,7 @@ class VersionsSuite extends SparkFunSuite with Logging {
ignore("failure sanity check") {
val e = intercept[Throwable] {
val badClient = quietly {
- IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client
+ IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).createClient()
}
}
assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'")
@@ -97,7 +97,7 @@ class VersionsSuite extends SparkFunSuite with Logging {
test(s"$version: create client") {
client = null
System.gc() // Hack to avoid SEGV on some JVM versions.
- client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).client
+ client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).createClient()
}
test(s"$version: createDatabase") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index fe63ad5683..2878500453 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -1133,6 +1133,38 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
conf.clear()
}
+ test("current_database with multiple sessions") {
+ sql("create database a")
+ sql("use a")
+ val s2 = newSession()
+ s2.sql("create database b")
+ s2.sql("use b")
+
+ assert(sql("select current_database()").first() === Row("a"))
+ assert(s2.sql("select current_database()").first() === Row("b"))
+
+ try {
+ sql("create table test_a(key INT, value STRING)")
+ s2.sql("create table test_b(key INT, value STRING)")
+
+ sql("select * from test_a")
+ intercept[AnalysisException] {
+ sql("select * from test_b")
+ }
+ sql("select * from b.test_b")
+
+ s2.sql("select * from test_b")
+ intercept[AnalysisException] {
+ s2.sql("select * from test_a")
+ }
+ s2.sql("select * from a.test_a")
+ } finally {
+ sql("DROP TABLE IF EXISTS test_a")
+ s2.sql("DROP TABLE IF EXISTS test_b")
+ }
+
+ }
+
createQueryTest("select from thrift based table",
"SELECT * from src_thrift")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index ec5b83b98e..ccc15eaa63 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -160,10 +160,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
test("show functions") {
- val allFunctions =
+ val allBuiltinFunctions =
(FunctionRegistry.builtin.listFunction().toSet[String] ++
org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames.asScala).toList.sorted
- checkAnswer(sql("SHOW functions"), allFunctions.map(Row(_)))
+ // The TestContext is shared by all the test cases, some functions may be registered before
+ // this, so we check that all the builtin functions are returned.
+ val allFunctions = sql("SHOW functions").collect().map(r => r(0))
+ allBuiltinFunctions.foreach { f =>
+ assert(allFunctions.contains(f))
+ }
checkAnswer(sql("SHOW functions abs"), Row("abs"))
checkAnswer(sql("SHOW functions 'abs'"), Row("abs"))
checkAnswer(sql("SHOW functions abc.abs"), Row("abs"))