author    Kunal Khamar <kkhamar@outlook.com>    2017-03-08 13:06:22 -0800
committer Shixiong Zhu <shixiong@databricks.com>    2017-03-08 13:20:45 -0800
commit    6570cfd7abe349dc6d2151f2ac9dc662e7465a79 (patch)
tree      97b54a89a3d228c737203989d6b68db5ec75d8ef
parent    1bf9012380de2aa7bdf39220b55748defde8b700 (diff)
download  spark-6570cfd7abe349dc6d2151f2ac9dc662e7465a79.tar.gz
          spark-6570cfd7abe349dc6d2151f2ac9dc662e7465a79.tar.bz2
          spark-6570cfd7abe349dc6d2151f2ac9dc662e7465a79.zip
[SPARK-19540][SQL] Add ability to clone SparkSession wherein cloned session has an identical copy of the SessionState
Forking a newSession() from SparkSession currently makes a new SparkSession that does not retain SessionState (i.e. temporary tables, SQL config, registered functions, etc.). This change adds a method cloneSession() which creates a new SparkSession with a copy of the parent's SessionState.

Subsequent changes to the base session are not propagated to the cloned session; the clone is independent after creation. If the base session is changed after the clone has been created, say the user registers a new UDF, then the new UDF will not be available inside the clone. The same goes for configs and temp tables.

Tested with unit tests.

Author: Kunal Khamar <kkhamar@outlook.com>
Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16826 from kunalkhamar/fork-sparksession.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala38
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala55
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala59
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala235
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala162
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala21
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala18
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala20
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala5
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala92
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala261
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala2
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala67
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala112
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala41
20 files changed, 981 insertions, 236 deletions
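
Illustrative sketch (not part of the patch): the behavior described in the commit message can be exercised roughly as below. cloneSession() is private[sql] in this change, so the call assumes code that lives in the org.apache.spark.sql package (for example a test suite), and the config key is invented for the example.

import org.apache.spark.sql.SparkSession

object CloneSessionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("clone-demo").getOrCreate()

    spark.conf.set("spark.sql.some.option", "base")        // session-local SQL config
    spark.range(10).createOrReplaceTempView("numbers")     // session-local temp view

    // Hypothetical call site: private[sql] access is assumed here.
    val clone = spark.cloneSession()

    // State present at clone time is copied over.
    assert(clone.conf.get("spark.sql.some.option") == "base")
    assert(clone.sql("SELECT count(*) FROM numbers").head.getLong(0) == 10L)

    // Changes made afterwards are not propagated in either direction.
    spark.conf.set("spark.sql.some.option", "changed-in-base")
    assert(clone.conf.get("spark.sql.some.option") == "base")

    spark.stop()
  }
}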
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index fb99cb27b8..cff0efa979 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -66,6 +66,8 @@ trait CatalystConf {
/** The maximum number of joined nodes allowed in the dynamic programming algorithm. */
def joinReorderDPThreshold: Int
+
+ override def clone(): CatalystConf = throw new CloneNotSupportedException()
}
@@ -85,4 +87,7 @@ case class SimpleCatalystConf(
joinReorderDPThreshold: Int = 12,
warehousePath: String = "/user/hive/warehouse",
sessionLocalTimeZone: String = TimeZone.getDefault().getID)
- extends CatalystConf
+ extends CatalystConf {
+
+ override def clone(): SimpleCatalystConf = this.copy()
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 556fa99017..0dcb44081f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -64,6 +64,8 @@ trait FunctionRegistry {
/** Clear all registered functions. */
def clear(): Unit
+ /** Create a copy of this registry with identical functions as this registry. */
+ override def clone(): FunctionRegistry = throw new CloneNotSupportedException()
}
class SimpleFunctionRegistry extends FunctionRegistry {
@@ -107,7 +109,7 @@ class SimpleFunctionRegistry extends FunctionRegistry {
functionBuilders.clear()
}
- def copy(): SimpleFunctionRegistry = synchronized {
+ override def clone(): SimpleFunctionRegistry = synchronized {
val registry = new SimpleFunctionRegistry
functionBuilders.iterator.foreach { case (name, (info, builder)) =>
registry.registerFunction(name, info, builder)
@@ -150,6 +152,7 @@ object EmptyFunctionRegistry extends FunctionRegistry {
throw new UnsupportedOperationException
}
+ override def clone(): FunctionRegistry = this
}
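
Illustrative sketch (not part of the patch): both CatalystConf and FunctionRegistry above adopt the same clone contract, a trait-level clone() that throws CloneNotSupportedException by default, with concrete overrides that use a covariant return type and copy the mutable state. The shape of that pattern with a hypothetical Registry type:

import scala.collection.mutable

trait Registry {
  def register(name: String, value: Int): Unit
  def lookup(name: String): Option[Int]
  // Default fails loudly, like CatalystConf.clone and FunctionRegistry.clone above.
  override def clone(): Registry = throw new CloneNotSupportedException()
}

class SimpleRegistry extends Registry {
  private val entries = new mutable.HashMap[String, Int]

  override def register(name: String, value: Int): Unit = synchronized {
    entries.put(name, value)
  }

  override def lookup(name: String): Option[Int] = synchronized(entries.get(name))

  // Copy every entry into a fresh instance so the clone is independent,
  // mirroring SimpleFunctionRegistry.clone().
  override def clone(): SimpleRegistry = synchronized {
    val copy = new SimpleRegistry
    entries.foreach { case (name, value) => copy.register(name, value) }
    copy
  }
}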
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 831e37aac1..6cfc4a4321 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -50,7 +50,6 @@ object SessionCatalog {
class SessionCatalog(
externalCatalog: ExternalCatalog,
globalTempViewManager: GlobalTempViewManager,
- functionResourceLoader: FunctionResourceLoader,
functionRegistry: FunctionRegistry,
conf: CatalystConf,
hadoopConf: Configuration,
@@ -66,16 +65,19 @@ class SessionCatalog(
this(
externalCatalog,
new GlobalTempViewManager("global_temp"),
- DummyFunctionResourceLoader,
functionRegistry,
conf,
new Configuration(),
CatalystSqlParser)
+ functionResourceLoader = DummyFunctionResourceLoader
}
// For testing only.
def this(externalCatalog: ExternalCatalog) {
- this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true))
+ this(
+ externalCatalog,
+ new SimpleFunctionRegistry,
+ SimpleCatalystConf(caseSensitiveAnalysis = true))
}
/** List of temporary tables, mapping from table name to their logical plan. */
@@ -89,6 +91,8 @@ class SessionCatalog(
@GuardedBy("this")
protected var currentDb = formatDatabaseName(DEFAULT_DATABASE)
+ @volatile var functionResourceLoader: FunctionResourceLoader = _
+
/**
* Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"),
* i.e. if this name only contains characters, numbers, and _.
@@ -987,6 +991,9 @@ class SessionCatalog(
* by a tuple (resource type, resource uri).
*/
def loadFunctionResources(resources: Seq[FunctionResource]): Unit = {
+ if (functionResourceLoader == null) {
+ throw new IllegalStateException("functionResourceLoader has not yet been initialized")
+ }
resources.foreach(functionResourceLoader.loadResource)
}
@@ -1182,4 +1189,29 @@ class SessionCatalog(
}
}
+ /**
+ * Create a new [[SessionCatalog]] with the provided parameters. `externalCatalog` and
+ * `globalTempViewManager` are `inherited`, while `currentDb` and `tempTables` are copied.
+ */
+ def newSessionCatalogWith(
+ conf: CatalystConf,
+ hadoopConf: Configuration,
+ functionRegistry: FunctionRegistry,
+ parser: ParserInterface): SessionCatalog = {
+ val catalog = new SessionCatalog(
+ externalCatalog,
+ globalTempViewManager,
+ functionRegistry,
+ conf,
+ hadoopConf,
+ parser)
+
+ synchronized {
+ catalog.currentDb = currentDb
+ // copy over temporary tables
+ tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2))
+ }
+
+ catalog
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 328a16c4bf..7e74dcdef0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.catalog
+import org.apache.hadoop.conf.Configuration
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, SimpleCatalystConf, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
@@ -1197,6 +1199,59 @@ class SessionCatalogSuite extends PlanTest {
}
}
+ test("clone SessionCatalog - temp views") {
+ val externalCatalog = newEmptyCatalog()
+ val original = new SessionCatalog(externalCatalog)
+ val tempTable1 = Range(1, 10, 1, 10)
+ original.createTempView("copytest1", tempTable1, overrideIfExists = false)
+
+ // check if tables copied over
+ val clone = original.newSessionCatalogWith(
+ SimpleCatalystConf(caseSensitiveAnalysis = true),
+ new Configuration(),
+ new SimpleFunctionRegistry,
+ CatalystSqlParser)
+ assert(original ne clone)
+ assert(clone.getTempView("copytest1") == Some(tempTable1))
+
+ // check if clone and original independent
+ clone.dropTable(TableIdentifier("copytest1"), ignoreIfNotExists = false, purge = false)
+ assert(original.getTempView("copytest1") == Some(tempTable1))
+
+ val tempTable2 = Range(1, 20, 2, 10)
+ original.createTempView("copytest2", tempTable2, overrideIfExists = false)
+ assert(clone.getTempView("copytest2").isEmpty)
+ }
+
+ test("clone SessionCatalog - current db") {
+ val externalCatalog = newEmptyCatalog()
+ val db1 = "db1"
+ val db2 = "db2"
+ val db3 = "db3"
+
+ externalCatalog.createDatabase(newDb(db1), ignoreIfExists = true)
+ externalCatalog.createDatabase(newDb(db2), ignoreIfExists = true)
+ externalCatalog.createDatabase(newDb(db3), ignoreIfExists = true)
+
+ val original = new SessionCatalog(externalCatalog)
+ original.setCurrentDatabase(db1)
+
+ // check if current db copied over
+ val clone = original.newSessionCatalogWith(
+ SimpleCatalystConf(caseSensitiveAnalysis = true),
+ new Configuration(),
+ new SimpleFunctionRegistry,
+ CatalystSqlParser)
+ assert(original ne clone)
+ assert(clone.getCurrentDatabase == db1)
+
+ // check if clone and original independent
+ clone.setCurrentDatabase(db2)
+ assert(original.getCurrentDatabase == db1)
+ original.setCurrentDatabase(db3)
+ assert(clone.getCurrentDatabase == db2)
+ }
+
test("SPARK-19737: detect undefined functions without triggering relation resolution") {
import org.apache.spark.sql.catalyst.dsl.plans._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
index 1e8ba51e59..bd8dd6ea3f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
@@ -46,4 +46,10 @@ class ExperimentalMethods private[sql]() {
@volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
+ override def clone(): ExperimentalMethods = {
+ val result = new ExperimentalMethods
+ result.extraStrategies = extraStrategies
+ result.extraOptimizations = extraOptimizations
+ result
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index afc1827e7e..49562578b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -21,7 +21,6 @@ import java.io.Closeable
import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConverters._
-import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -43,7 +42,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
-import org.apache.spark.sql.types.{DataType, LongType, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ExecutionListenerManager
import org.apache.spark.util.Utils
@@ -67,15 +66,22 @@ import org.apache.spark.util.Utils
* .config("spark.some.config.option", "some-value")
* .getOrCreate()
* }}}
+ *
+ * @param sparkContext The Spark context associated with this Spark session.
+ * @param existingSharedState If supplied, use the existing shared state
+ * instead of creating a new one.
+ * @param parentSessionState If supplied, inherit all session state (i.e. temporary
+ * views, SQL config, UDFs etc) from parent.
*/
@InterfaceStability.Stable
class SparkSession private(
@transient val sparkContext: SparkContext,
- @transient private val existingSharedState: Option[SharedState])
+ @transient private val existingSharedState: Option[SharedState],
+ @transient private val parentSessionState: Option[SessionState])
extends Serializable with Closeable with Logging { self =>
private[sql] def this(sc: SparkContext) {
- this(sc, None)
+ this(sc, None, None)
}
sparkContext.assertNotStopped()
@@ -108,6 +114,7 @@ class SparkSession private(
/**
* State isolated across sessions, including SQL configurations, temporary tables, registered
* functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]].
+ * If `parentSessionState` is not null, the `SessionState` will be a copy of the parent.
*
* This is internal to Spark and there is no guarantee on interface stability.
*
@@ -116,9 +123,13 @@ class SparkSession private(
@InterfaceStability.Unstable
@transient
lazy val sessionState: SessionState = {
- SparkSession.reflect[SessionState, SparkSession](
- SparkSession.sessionStateClassName(sparkContext.conf),
- self)
+ parentSessionState
+ .map(_.clone(this))
+ .getOrElse {
+ SparkSession.instantiateSessionState(
+ SparkSession.sessionStateClassName(sparkContext.conf),
+ self)
+ }
}
/**
@@ -208,7 +219,25 @@ class SparkSession private(
* @since 2.0.0
*/
def newSession(): SparkSession = {
- new SparkSession(sparkContext, Some(sharedState))
+ new SparkSession(sparkContext, Some(sharedState), parentSessionState = None)
+ }
+
+ /**
+ * Create an identical copy of this `SparkSession`, sharing the underlying `SparkContext`
+ * and shared state. All the state of this session (i.e. SQL configurations, temporary tables,
+ * registered functions) is copied over, and the cloned session is set up with the same shared
+ * state as this session. The cloned session is independent of this session, that is, any
+ * non-global change in either session is not reflected in the other.
+ *
+ * @note Other than the `SparkContext`, all shared state is initialized lazily.
+ * This method will force the initialization of the shared state to ensure that parent
+ * and child sessions are set up with the same shared state. If the underlying catalog
+ * implementation is Hive, this will initialize the metastore, which may take some time.
+ */
+ private[sql] def cloneSession(): SparkSession = {
+ val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState))
+ result.sessionState // force copy of SessionState
+ result
}
@@ -971,16 +1000,18 @@ object SparkSession {
}
/**
- * Helper method to create an instance of [[T]] using a single-arg constructor that
- * accepts an [[Arg]].
+ * Helper method to create an instance of `SessionState` based on `className` from conf.
+ * The result is either `SessionState` or `HiveSessionState`.
*/
- private def reflect[T, Arg <: AnyRef](
+ private def instantiateSessionState(
className: String,
- ctorArg: Arg)(implicit ctorArgTag: ClassTag[Arg]): T = {
+ sparkSession: SparkSession): SessionState = {
+
try {
+ // get `SessionState.apply(SparkSession)`
val clazz = Utils.classForName(className)
- val ctor = clazz.getDeclaredConstructor(ctorArgTag.runtimeClass)
- ctor.newInstance(ctorArg).asInstanceOf[T]
+ val method = clazz.getMethod("apply", sparkSession.getClass)
+ method.invoke(null, sparkSession).asInstanceOf[SessionState]
} catch {
case NonFatal(e) =>
throw new IllegalArgumentException(s"Error while instantiating '$className':", e)
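
Illustrative sketch (not part of the patch): instantiateSessionState above works because the Scala compiler emits a static forwarder for the companion object's apply method on the companion class, which is why method.invoke(null, ...) dispatches to it. The same lookup in standalone form, with hypothetical class names and Class.forName standing in for Spark's Utils.classForName:

// Hypothetical factory resolved by class name at runtime. Assumes the target class has a
// companion object whose apply(Arg) is exposed as a static forwarder on the class.
object ReflectiveApply {
  def create[T](className: String, arg: AnyRef, argClass: Class[_]): T = {
    val clazz = Class.forName(className)
    // Look up the static forwarder generated for the companion's apply method.
    val applyMethod = clazz.getMethod("apply", argClass)
    // Static methods are invoked with a null receiver.
    applyMethod.invoke(null, arg).asInstanceOf[T]
  }
}

final case class Config(name: String)

class State(val config: Config)

object State {
  def apply(config: Config): State = new State(config)
}

object ReflectiveApplyDemo {
  def main(args: Array[String]): Unit = {
    val state = ReflectiveApply.create[State](classOf[State].getName, Config("demo"), classOf[Config])
    println(state.config.name) // prints "demo"
  }
}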
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 4d781b96ab..8b598cc60e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -66,7 +66,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
* Preprocess [[CreateTable]], to do some normalization and checking.
*/
case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[LogicalPlan] {
- private val catalog = sparkSession.sessionState.catalog
+ // catalog is a def and not a val/lazy val as the latter would introduce a circular reference
+ private def catalog = sparkSession.sessionState.catalog
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// When we CREATE TABLE without specifying the table schema, we should fail the query if
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 94e3fa7dd1..1244f690fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1019,6 +1019,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def clear(): Unit = {
settings.clear()
}
+
+ override def clone(): SQLConf = {
+ val result = new SQLConf
+ getAllConfs.foreach {
+ case(k, v) => if (v ne null) result.setConfString(k, v)
+ }
+ result
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 6908560511..ce80604bd3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -22,38 +22,49 @@ import java.io.File
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.command.AnalyzeTableCommand
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager}
+import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.util.ExecutionListenerManager
/**
* A class that holds all session-specific state in a given [[SparkSession]].
+ * @param sparkContext The [[SparkContext]].
+ * @param sharedState The shared state.
+ * @param conf SQL-specific key-value configurations.
+ * @param experimentalMethods The experimental methods.
+ * @param functionRegistry Internal catalog for managing functions registered by the user.
+ * @param catalog Internal catalog for managing table and database states.
+ * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
+ * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations.
+ * @param streamingQueryManager Interface to start and stop
+ * [[org.apache.spark.sql.streaming.StreamingQuery]]s.
+ * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]]
*/
-private[sql] class SessionState(sparkSession: SparkSession) {
+private[sql] class SessionState(
+ sparkContext: SparkContext,
+ sharedState: SharedState,
+ val conf: SQLConf,
+ val experimentalMethods: ExperimentalMethods,
+ val functionRegistry: FunctionRegistry,
+ val catalog: SessionCatalog,
+ val sqlParser: ParserInterface,
+ val analyzer: Analyzer,
+ val streamingQueryManager: StreamingQueryManager,
+ val queryExecutionCreator: LogicalPlan => QueryExecution) {
- // Note: These are all lazy vals because they depend on each other (e.g. conf) and we
- // want subclasses to override some of the fields. Otherwise, we would get a lot of NPEs.
-
- /**
- * SQL-specific key-value configurations.
- */
- lazy val conf: SQLConf = new SQLConf
-
- def newHadoopConf(): Configuration = {
- val hadoopConf = new Configuration(sparkSession.sparkContext.hadoopConfiguration)
- conf.getAllConfs.foreach { case (k, v) => if (v ne null) hadoopConf.set(k, v) }
- hadoopConf
- }
+ def newHadoopConf(): Configuration = SessionState.newHadoopConf(
+ sparkContext.hadoopConfiguration,
+ conf)
def newHadoopConfWithOptions(options: Map[String, String]): Configuration = {
val hadoopConf = newHadoopConf()
@@ -65,22 +76,15 @@ private[sql] class SessionState(sparkSession: SparkSession) {
hadoopConf
}
- lazy val experimentalMethods = new ExperimentalMethods
-
- /**
- * Internal catalog for managing functions registered by the user.
- */
- lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()
-
/**
* A class for loading resources specified by a function.
*/
- lazy val functionResourceLoader: FunctionResourceLoader = {
+ val functionResourceLoader: FunctionResourceLoader = {
new FunctionResourceLoader {
override def loadResource(resource: FunctionResource): Unit = {
resource.resourceType match {
case JarResource => addJar(resource.uri)
- case FileResource => sparkSession.sparkContext.addFile(resource.uri)
+ case FileResource => sparkContext.addFile(resource.uri)
case ArchiveResource =>
throw new AnalysisException(
"Archive is not allowed to be loaded. If YARN mode is used, " +
@@ -91,92 +95,77 @@ private[sql] class SessionState(sparkSession: SparkSession) {
}
/**
- * Internal catalog for managing table and database states.
- */
- lazy val catalog = new SessionCatalog(
- sparkSession.sharedState.externalCatalog,
- sparkSession.sharedState.globalTempViewManager,
- functionResourceLoader,
- functionRegistry,
- conf,
- newHadoopConf(),
- sqlParser)
-
- /**
* Interface exposed to the user for registering user-defined functions.
* Note that the user-defined functions must be deterministic.
*/
- lazy val udf: UDFRegistration = new UDFRegistration(functionRegistry)
-
- /**
- * Logical query plan analyzer for resolving unresolved attributes and relations.
- */
- lazy val analyzer: Analyzer = {
- new Analyzer(catalog, conf) {
- override val extendedResolutionRules =
- new FindDataSourceTable(sparkSession) ::
- new ResolveSQLOnFile(sparkSession) :: Nil
-
- override val postHocResolutionRules =
- PreprocessTableCreation(sparkSession) ::
- PreprocessTableInsertion(conf) ::
- DataSourceAnalysis(conf) :: Nil
-
- override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck)
- }
- }
+ val udf: UDFRegistration = new UDFRegistration(functionRegistry)
/**
* Logical query plan optimizer.
*/
- lazy val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods)
-
- /**
- * Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
- */
- lazy val sqlParser: ParserInterface = new SparkSqlParser(conf)
+ val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods)
/**
* Planner that converts optimized logical plans to physical plans.
*/
def planner: SparkPlanner =
- new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies)
+ new SparkPlanner(sparkContext, conf, experimentalMethods.extraStrategies)
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
* that listen for execution metrics.
*/
- lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
+ val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
/**
- * Interface to start and stop [[StreamingQuery]]s.
+ * Get an identical copy of the `SessionState` and associate it with the given `SparkSession`
*/
- lazy val streamingQueryManager: StreamingQueryManager = {
- new StreamingQueryManager(sparkSession)
- }
+ def clone(newSparkSession: SparkSession): SessionState = {
+ val sparkContext = newSparkSession.sparkContext
+ val confCopy = conf.clone()
+ val functionRegistryCopy = functionRegistry.clone()
+ val sqlParser: ParserInterface = new SparkSqlParser(confCopy)
+ val catalogCopy = catalog.newSessionCatalogWith(
+ confCopy,
+ SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy),
+ functionRegistryCopy,
+ sqlParser)
+ val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan)
- private val jarClassLoader: NonClosableMutableURLClassLoader =
- sparkSession.sharedState.jarClassLoader
+ SessionState.mergeSparkConf(confCopy, sparkContext.getConf)
- // Automatically extract all entries and put it in our SQLConf
- // We need to call it after all of vals have been initialized.
- sparkSession.sparkContext.getConf.getAll.foreach { case (k, v) =>
- conf.setConfString(k, v)
+ new SessionState(
+ sparkContext,
+ newSparkSession.sharedState,
+ confCopy,
+ experimentalMethods.clone(),
+ functionRegistryCopy,
+ catalogCopy,
+ sqlParser,
+ SessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy),
+ new StreamingQueryManager(newSparkSession),
+ queryExecutionCreator)
}
// ------------------------------------------------------
// Helper methods, partially leftover from pre-2.0 days
// ------------------------------------------------------
- def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan)
+ def executePlan(plan: LogicalPlan): QueryExecution = queryExecutionCreator(plan)
def refreshTable(tableName: String): Unit = {
catalog.refreshTable(sqlParser.parseTableIdentifier(tableName))
}
+ /**
+ * Add a jar path to [[SparkContext]] and the classloader.
+ *
+ * Note: this method does not seem to access any session state, but the subclass `HiveSessionState`
+ * needs to add the jar to its Hive client for the current session. Hence, it still needs to be in
+ * [[SessionState]].
+ */
def addJar(path: String): Unit = {
- sparkSession.sparkContext.addJar(path)
-
+ sparkContext.addJar(path)
val uri = new Path(path).toUri
val jarURL = if (uri.getScheme == null) {
// `path` is a local file path without a URL scheme
@@ -185,15 +174,93 @@ private[sql] class SessionState(sparkSession: SparkSession) {
// `path` is a URL with a scheme
uri.toURL
}
- jarClassLoader.addURL(jarURL)
- Thread.currentThread().setContextClassLoader(jarClassLoader)
+ sharedState.jarClassLoader.addURL(jarURL)
+ Thread.currentThread().setContextClassLoader(sharedState.jarClassLoader)
+ }
+}
+
+
+private[sql] object SessionState {
+
+ def apply(sparkSession: SparkSession): SessionState = {
+ apply(sparkSession, new SQLConf)
+ }
+
+ def apply(sparkSession: SparkSession, sqlConf: SQLConf): SessionState = {
+ val sparkContext = sparkSession.sparkContext
+
+ // Automatically extract all entries and put them in our SQLConf
+ mergeSparkConf(sqlConf, sparkContext.getConf)
+
+ val functionRegistry = FunctionRegistry.builtin.clone()
+
+ val sqlParser: ParserInterface = new SparkSqlParser(sqlConf)
+
+ val catalog = new SessionCatalog(
+ sparkSession.sharedState.externalCatalog,
+ sparkSession.sharedState.globalTempViewManager,
+ functionRegistry,
+ sqlConf,
+ newHadoopConf(sparkContext.hadoopConfiguration, sqlConf),
+ sqlParser)
+
+ val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, sqlConf)
+
+ val streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(sparkSession)
+
+ val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(sparkSession, plan)
+
+ val sessionState = new SessionState(
+ sparkContext,
+ sparkSession.sharedState,
+ sqlConf,
+ new ExperimentalMethods,
+ functionRegistry,
+ catalog,
+ sqlParser,
+ analyzer,
+ streamingQueryManager,
+ queryExecutionCreator)
+ // functionResourceLoader needs to access SessionState.addJar, so it cannot be created before
+ // creating SessionState. Setting `catalog.functionResourceLoader` here is safe since the caller
+ // cannot use SessionCatalog before we return SessionState.
+ catalog.functionResourceLoader = sessionState.functionResourceLoader
+ sessionState
+ }
+
+ def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = {
+ val newHadoopConf = new Configuration(hadoopConf)
+ sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) }
+ newHadoopConf
+ }
+
+ /**
+ * Create a logical query plan `Analyzer` with rules specific to a non-Hive `SessionState`.
+ */
+ private def createAnalyzer(
+ sparkSession: SparkSession,
+ catalog: SessionCatalog,
+ sqlConf: SQLConf): Analyzer = {
+ new Analyzer(catalog, sqlConf) {
+ override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
+ new FindDataSourceTable(sparkSession) ::
+ new ResolveSQLOnFile(sparkSession) :: Nil
+
+ override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
+ PreprocessTableCreation(sparkSession) ::
+ PreprocessTableInsertion(sqlConf) ::
+ DataSourceAnalysis(sqlConf) :: Nil
+
+ override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck)
+ }
}
/**
- * Analyzes the given table in the current database to generate statistics, which will be
- * used in query optimizations.
+ * Extract entries from `SparkConf` and put them in the `SQLConf`
*/
- def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = {
- AnalyzeTableCommand(tableIdent, noscan).run(sparkSession)
+ def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = {
+ sparkConf.getAll.foreach { case (k, v) =>
+ sqlConf.setConfString(k, v)
+ }
}
}
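
Illustrative sketch (not part of the patch): the refactor above replaces SessionState's interdependent lazy vals with plain constructor parameters, a companion apply() factory that wires the collaborators together, and a clone method that copies the mutable pieces (conf, function registry, catalog) while reusing the shared components. The shape of that pattern with stand-in types:

import scala.collection.mutable

class ConfStub(private var settings: Map[String, String] = Map.empty) {
  def set(k: String, v: String): Unit = settings += (k -> v)
  def get(k: String): Option[String] = settings.get(k)
  override def clone(): ConfStub = new ConfStub(settings)  // independent copy of the settings
}

class CatalogStub(var currentDb: String, val tempViews: mutable.Map[String, String]) {
  // Analogue of newSessionCatalogWith: duplicate session-local state into a new instance.
  def copyWith(): CatalogStub = new CatalogStub(currentDb, tempViews.clone())
}

class SharedStub // analogue of SharedState: reused by every clone, never copied

class StateStub(val shared: SharedStub, val conf: ConfStub, val catalog: CatalogStub) {
  // Analogue of SessionState.clone: copy the mutable collaborators, keep the shared ones.
  def cloneState(): StateStub = new StateStub(shared, conf.clone(), catalog.copyWith())
}

object StateStub {
  // Analogue of SessionState.apply: the factory builds and wires the collaborators.
  def apply(shared: SharedStub): StateStub =
    new StateStub(shared, new ConfStub(), new CatalogStub("default", mutable.Map.empty))
}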
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
new file mode 100644
index 0000000000..2d5e37242a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+class SessionStateSuite extends SparkFunSuite
+ with BeforeAndAfterEach with BeforeAndAfterAll {
+
+ /**
+ * A shared SparkSession for all tests in this suite. Make sure you reset any changes to this
+ * session as this is a singleton HiveSparkSession in HiveSessionStateSuite and it's shared
+ * with all Hive test suites.
+ */
+ protected var activeSession: SparkSession = _
+
+ override def beforeAll(): Unit = {
+ activeSession = SparkSession.builder().master("local").getOrCreate()
+ }
+
+ override def afterAll(): Unit = {
+ if (activeSession != null) {
+ activeSession.stop()
+ activeSession = null
+ }
+ super.afterAll()
+ }
+
+ test("fork new session and inherit RuntimeConfig options") {
+ val key = "spark-config-clone"
+ try {
+ activeSession.conf.set(key, "active")
+
+ // inheritance
+ val forkedSession = activeSession.cloneSession()
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.conf ne activeSession.conf)
+ assert(forkedSession.conf.get(key) == "active")
+
+ // independence
+ forkedSession.conf.set(key, "forked")
+ assert(activeSession.conf.get(key) == "active")
+ activeSession.conf.set(key, "dontcopyme")
+ assert(forkedSession.conf.get(key) == "forked")
+ } finally {
+ activeSession.conf.unset(key)
+ }
+ }
+
+ test("fork new session and inherit function registry and udf") {
+ val testFuncName1 = "strlenScala"
+ val testFuncName2 = "addone"
+ try {
+ activeSession.udf.register(testFuncName1, (_: String).length + (_: Int))
+ val forkedSession = activeSession.cloneSession()
+
+ // inheritance
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.sessionState.functionRegistry ne
+ activeSession.sessionState.functionRegistry)
+ assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty)
+
+ // independence
+ forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1)
+ assert(activeSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty)
+ activeSession.udf.register(testFuncName2, (_: Int) + 1)
+ assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName2).isEmpty)
+ } finally {
+ activeSession.sessionState.functionRegistry.dropFunction(testFuncName1)
+ activeSession.sessionState.functionRegistry.dropFunction(testFuncName2)
+ }
+ }
+
+ test("fork new session and inherit experimental methods") {
+ val originalExtraOptimizations = activeSession.experimental.extraOptimizations
+ val originalExtraStrategies = activeSession.experimental.extraStrategies
+ try {
+ object DummyRule1 extends Rule[LogicalPlan] {
+ def apply(p: LogicalPlan): LogicalPlan = p
+ }
+ object DummyRule2 extends Rule[LogicalPlan] {
+ def apply(p: LogicalPlan): LogicalPlan = p
+ }
+ val optimizations = List(DummyRule1, DummyRule2)
+ activeSession.experimental.extraOptimizations = optimizations
+ val forkedSession = activeSession.cloneSession()
+
+ // inheritance
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.experimental ne activeSession.experimental)
+ assert(forkedSession.experimental.extraOptimizations.toSet ==
+ activeSession.experimental.extraOptimizations.toSet)
+
+ // independence
+ forkedSession.experimental.extraOptimizations = List(DummyRule2)
+ assert(activeSession.experimental.extraOptimizations == optimizations)
+ activeSession.experimental.extraOptimizations = List(DummyRule1)
+ assert(forkedSession.experimental.extraOptimizations == List(DummyRule2))
+ } finally {
+ activeSession.experimental.extraOptimizations = originalExtraOptimizations
+ activeSession.experimental.extraStrategies = originalExtraStrategies
+ }
+ }
+
+ test("fork new sessions and run query on inherited table") {
+ def checkTableExists(sparkSession: SparkSession): Unit = {
+ QueryTest.checkAnswer(sparkSession.sql(
+ """
+ |SELECT x.str, COUNT(*)
+ |FROM df x JOIN df y ON x.str = y.str
+ |GROUP BY x.str
+ """.stripMargin),
+ Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
+ }
+
+ val spark = activeSession
+ // Cannot use `import activeSession.implicits._` due to the compiler limitation.
+ import spark.implicits._
+
+ try {
+ activeSession
+ .createDataset[(Int, String)](Seq(1, 2, 3).map(i => (i, i.toString)))
+ .toDF("int", "str")
+ .createOrReplaceTempView("df")
+ checkTableExists(activeSession)
+
+ val forkedSession = activeSession.cloneSession()
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.sessionState ne activeSession.sessionState)
+ checkTableExists(forkedSession)
+ checkTableExists(activeSession.cloneSession()) // ability to clone multiple times
+ checkTableExists(forkedSession.cloneSession()) // clone of clone
+ } finally {
+ activeSession.sql("drop table df")
+ }
+ }
+
+ test("fork new session and inherit reference to SharedState") {
+ val forkedSession = activeSession.cloneSession()
+ assert(activeSession.sharedState eq forkedSession.sharedState)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 989a7f2698..fcb8ffbc6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -493,6 +493,25 @@ class CatalogSuite
}
}
- // TODO: add tests for the rest of them
+ test("clone Catalog") {
+ // need to test tempTables are cloned
+ assert(spark.catalog.listTables().collect().isEmpty)
+ createTempTable("my_temp_table")
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+
+ // inheritance
+ val forkedSession = spark.cloneSession()
+ assert(spark ne forkedSession)
+ assert(spark.catalog ne forkedSession.catalog)
+ assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+
+ // independence
+ dropTable("my_temp_table") // drop table in original session
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set())
+ assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+ forkedSession.sessionState.catalog
+ .createTempView("fork_table", Range(1, 2, 3, 4), overrideIfExists = true)
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set())
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
index 0e3a5ca9d7..f2456c7704 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
@@ -187,4 +187,22 @@ class SQLConfEntrySuite extends SparkFunSuite {
}
assert(e2.getMessage === "The maximum size of the cache must not be negative")
}
+
+ test("clone SQLConf") {
+ val original = new SQLConf
+ val key = "spark.sql.SQLConfEntrySuite.clone"
+ assert(original.getConfString(key, "noentry") === "noentry")
+
+ // inheritance
+ original.setConfString(key, "orig")
+ val clone = original.clone()
+ assert(original ne clone)
+ assert(clone.getConfString(key, "noentry") === "orig")
+
+ // independence
+ clone.setConfString(key, "clone")
+ assert(original.getConfString(key, "noentry") === "orig")
+ original.setConfString(key, "dontcopyme")
+ assert(clone.getConfString(key, "noentry") === "clone")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 8ab6db175d..898a2fb4f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -35,18 +35,16 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) {
}
@transient
- override lazy val sessionState: SessionState = new SessionState(self) {
- override lazy val conf: SQLConf = {
- new SQLConf {
- clear()
- override def clear(): Unit = {
- super.clear()
- // Make sure we start with the default test configs even after clear
- TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) }
- }
+ override lazy val sessionState: SessionState = SessionState(
+ this,
+ new SQLConf {
+ clear()
+ override def clear(): Unit = {
+ super.clear()
+ // Make sure we start with the default test configs even after clear
+ TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) }
}
- }
- }
+ })
// Needed for Java tests
def loadTestData(): Unit = {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 4d3b6c3cec..d135dfa9f4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -41,8 +41,9 @@ import org.apache.spark.sql.types._
* cleaned up to integrate more nicely with [[HiveExternalCatalog]].
*/
private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging {
- private val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState]
- private lazy val tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache
+ // these are def_s and not val/lazy val since the latter would introduce circular references
+ private def sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState]
+ private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache
private def getCurrentDatabase: String = sessionState.catalog.getCurrentDatabase
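
Illustrative sketch (not part of the patch): sessionState and tableRelationCache become defs because HiveMetastoreCatalog is now constructed while the session's lazy sessionState is itself being initialized, so reading it eagerly would re-enter the in-progress lazy val. A toy reproduction with hypothetical names:

class Session {
  lazy val state: State = new State(new Helper(this)) // Helper is built while `state` initializes
}

class State(val helper: Helper)

class Helper(session: Session) {
  // BAD: `private val state = session.state` here would re-trigger the lazy initializer
  // above and recurse until the stack overflows.
  // GOOD: a def defers the read until Session.state has finished initializing.
  def state: State = session.state
}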
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index f1ea86890c..6b7599e3d3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry}
import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF}
import org.apache.spark.sql.{AnalysisException, SparkSession}
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{CatalystConf, FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog}
@@ -43,31 +43,23 @@ import org.apache.spark.util.Utils
private[sql] class HiveSessionCatalog(
externalCatalog: HiveExternalCatalog,
globalTempViewManager: GlobalTempViewManager,
- sparkSession: SparkSession,
- functionResourceLoader: FunctionResourceLoader,
+ private val metastoreCatalog: HiveMetastoreCatalog,
functionRegistry: FunctionRegistry,
conf: SQLConf,
hadoopConf: Configuration,
parser: ParserInterface)
extends SessionCatalog(
- externalCatalog,
- globalTempViewManager,
- functionResourceLoader,
- functionRegistry,
- conf,
- hadoopConf,
- parser) {
+ externalCatalog,
+ globalTempViewManager,
+ functionRegistry,
+ conf,
+ hadoopConf,
+ parser) {
// ----------------------------------------------------------------
// | Methods and fields for interacting with HiveMetastoreCatalog |
// ----------------------------------------------------------------
- // Catalog for handling data source tables. TODO: This really doesn't belong here since it is
- // essentially a cache for metastore tables. However, it relies on a lot of session-specific
- // things so it would be a lot of work to split its functionality between HiveSessionCatalog
- // and HiveCatalog. We should still do it at some point...
- private val metastoreCatalog = new HiveMetastoreCatalog(sparkSession)
-
// These 2 rules must be run before all other DDL post-hoc resolution rules, i.e.
// `PreprocessTableCreation`, `PreprocessTableInsertion`, `DataSourceAnalysis` and `HiveAnalysis`.
val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions
@@ -77,10 +69,51 @@ private[sql] class HiveSessionCatalog(
metastoreCatalog.hiveDefaultTableFilePath(name)
}
+ /**
+ * Create a new [[HiveSessionCatalog]] with the provided parameters. `externalCatalog` and
+ * `globalTempViewManager` are `inherited`, while `currentDb` and `tempTables` are copied.
+ */
+ def newSessionCatalogWith(
+ newSparkSession: SparkSession,
+ conf: SQLConf,
+ hadoopConf: Configuration,
+ functionRegistry: FunctionRegistry,
+ parser: ParserInterface): HiveSessionCatalog = {
+ val catalog = HiveSessionCatalog(
+ newSparkSession,
+ functionRegistry,
+ conf,
+ hadoopConf,
+ parser)
+
+ synchronized {
+ catalog.currentDb = currentDb
+ // copy over temporary tables
+ tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2))
+ }
+
+ catalog
+ }
+
+ /**
+ * The parent class [[SessionCatalog]] cannot access the [[SparkSession]] class, so we cannot add
+ * a [[SparkSession]] parameter to [[SessionCatalog.newSessionCatalogWith]]. However,
+ * [[HiveSessionCatalog]] requires a [[SparkSession]] parameter, so we add a new version of
+ * `newSessionCatalogWith` and disable this one.
+ *
+ * TODO Refactor HiveSessionCatalog to not use [[SparkSession]] directly.
+ */
+ override def newSessionCatalogWith(
+ conf: CatalystConf,
+ hadoopConf: Configuration,
+ functionRegistry: FunctionRegistry,
+ parser: ParserInterface): HiveSessionCatalog = throw new UnsupportedOperationException(
+ "to clone HiveSessionCatalog, use the other clone method that also accepts a SparkSession")
+
// For testing only
private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = {
val key = metastoreCatalog.getQualifiedTableName(table)
- sparkSession.sessionState.catalog.tableRelationCache.getIfPresent(key)
+ tableRelationCache.getIfPresent(key)
}
override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = {
@@ -217,3 +250,28 @@ private[sql] class HiveSessionCatalog(
"histogram_numeric"
)
}
+
+private[sql] object HiveSessionCatalog {
+
+ def apply(
+ sparkSession: SparkSession,
+ functionRegistry: FunctionRegistry,
+ conf: SQLConf,
+ hadoopConf: Configuration,
+ parser: ParserInterface): HiveSessionCatalog = {
+ // Catalog for handling data source tables. TODO: This really doesn't belong here since it is
+ // essentially a cache for metastore tables. However, it relies on a lot of session-specific
+ // things so it would be a lot of work to split its functionality between HiveSessionCatalog
+ // and HiveCatalog. We should still do it at some point...
+ val metastoreCatalog = new HiveMetastoreCatalog(sparkSession)
+
+ new HiveSessionCatalog(
+ sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog],
+ sparkSession.sharedState.globalTempViewManager,
+ metastoreCatalog,
+ functionRegistry,
+ conf,
+ hadoopConf,
+ parser)
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index 5a08a6bc66..cb8bcb8591 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -17,89 +17,65 @@
package org.apache.spark.sql.hive
+import org.apache.spark.SparkContext
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.Analyzer
-import org.apache.spark.sql.execution.SparkPlanner
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner, SparkSqlParser}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.hive.client.HiveClient
-import org.apache.spark.sql.internal.SessionState
+import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
+import org.apache.spark.sql.streaming.StreamingQueryManager
/**
* A class that holds all session-specific state in a given [[SparkSession]] backed by Hive.
+ * @param sparkContext The [[SparkContext]].
+ * @param sharedState The shared state.
+ * @param conf SQL-specific key-value configurations.
+ * @param experimentalMethods The experimental methods.
+ * @param functionRegistry Internal catalog for managing functions registered by the user.
+ * @param catalog Internal catalog for managing table and database states that uses Hive client for
+ * interacting with the metastore.
+ * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
+ * @param metadataHive The Hive metadata client.
+ * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations.
+ * @param streamingQueryManager Interface to start and stop
+ * [[org.apache.spark.sql.streaming.StreamingQuery]]s.
+ * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]]
+ * @param plannerCreator Lambda to create a planner that takes into account Hive-specific strategies
*/
-private[hive] class HiveSessionState(sparkSession: SparkSession)
- extends SessionState(sparkSession) {
-
- self =>
-
- /**
- * A Hive client used for interacting with the metastore.
- */
- lazy val metadataHive: HiveClient =
- sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client.newSession()
-
- /**
- * Internal catalog for managing table and database states.
- */
- override lazy val catalog = {
- new HiveSessionCatalog(
- sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog],
- sparkSession.sharedState.globalTempViewManager,
- sparkSession,
- functionResourceLoader,
- functionRegistry,
+private[hive] class HiveSessionState(
+ sparkContext: SparkContext,
+ sharedState: SharedState,
+ conf: SQLConf,
+ experimentalMethods: ExperimentalMethods,
+ functionRegistry: FunctionRegistry,
+ override val catalog: HiveSessionCatalog,
+ sqlParser: ParserInterface,
+ val metadataHive: HiveClient,
+ analyzer: Analyzer,
+ streamingQueryManager: StreamingQueryManager,
+ queryExecutionCreator: LogicalPlan => QueryExecution,
+ val plannerCreator: () => SparkPlanner)
+ extends SessionState(
+ sparkContext,
+ sharedState,
conf,
- newHadoopConf(),
- sqlParser)
- }
-
- /**
- * An analyzer that uses the Hive metastore.
- */
- override lazy val analyzer: Analyzer = {
- new Analyzer(catalog, conf) {
- override val extendedResolutionRules =
- new ResolveHiveSerdeTable(sparkSession) ::
- new FindDataSourceTable(sparkSession) ::
- new ResolveSQLOnFile(sparkSession) :: Nil
-
- override val postHocResolutionRules =
- new DetermineTableStats(sparkSession) ::
- catalog.ParquetConversions ::
- catalog.OrcConversions ::
- PreprocessTableCreation(sparkSession) ::
- PreprocessTableInsertion(conf) ::
- DataSourceAnalysis(conf) ::
- HiveAnalysis :: Nil
-
- override val extendedCheckRules = Seq(PreWriteCheck)
- }
- }
+ experimentalMethods,
+ functionRegistry,
+ catalog,
+ sqlParser,
+ analyzer,
+ streamingQueryManager,
+ queryExecutionCreator) { self =>
/**
* Planner that takes into account Hive-specific strategies.
*/
- override def planner: SparkPlanner = {
- new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies)
- with HiveStrategies {
- override val sparkSession: SparkSession = self.sparkSession
-
- override def strategies: Seq[Strategy] = {
- experimentalMethods.extraStrategies ++ Seq(
- FileSourceStrategy,
- DataSourceStrategy,
- SpecialLimits,
- InMemoryScans,
- HiveTableScans,
- Scripts,
- Aggregation,
- JoinSelection,
- BasicOperators
- )
- }
- }
- }
+ override def planner: SparkPlanner = plannerCreator()
// ------------------------------------------------------
@@ -146,4 +122,149 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC)
}
+ /**
+ * Get an identical copy of the `HiveSessionState`.
+ * This should ideally reuse the `SessionState.clone` but cannot do so.
+ * Doing that will throw an exception when trying to clone the catalog.
+ */
+ override def clone(newSparkSession: SparkSession): HiveSessionState = {
+ val sparkContext = newSparkSession.sparkContext
+ val confCopy = conf.clone()
+ val functionRegistryCopy = functionRegistry.clone()
+ val experimentalMethodsCopy = experimentalMethods.clone()
+ val sqlParser: ParserInterface = new SparkSqlParser(confCopy)
+ val catalogCopy = catalog.newSessionCatalogWith(
+ newSparkSession,
+ confCopy,
+ SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy),
+ functionRegistryCopy,
+ sqlParser)
+ val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan)
+
+ val hiveClient =
+ newSparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
+ .newSession()
+
+ SessionState.mergeSparkConf(confCopy, sparkContext.getConf)
+
+ new HiveSessionState(
+ sparkContext,
+ newSparkSession.sharedState,
+ confCopy,
+ experimentalMethodsCopy,
+ functionRegistryCopy,
+ catalogCopy,
+ sqlParser,
+ hiveClient,
+ HiveSessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy),
+ new StreamingQueryManager(newSparkSession),
+ queryExecutionCreator,
+ HiveSessionState.createPlannerCreator(
+ newSparkSession,
+ confCopy,
+ experimentalMethodsCopy))
+ }
+
+}
+
+private[hive] object HiveSessionState {
+
+ def apply(sparkSession: SparkSession): HiveSessionState = {
+ apply(sparkSession, new SQLConf)
+ }
+
+ def apply(sparkSession: SparkSession, conf: SQLConf): HiveSessionState = {
+ val initHelper = SessionState(sparkSession, conf)
+
+ val sparkContext = sparkSession.sparkContext
+
+ val catalog = HiveSessionCatalog(
+ sparkSession,
+ initHelper.functionRegistry,
+ initHelper.conf,
+ SessionState.newHadoopConf(sparkContext.hadoopConfiguration, initHelper.conf),
+ initHelper.sqlParser)
+
+ val metadataHive: HiveClient =
+ sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
+ .newSession()
+
+ val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, initHelper.conf)
+
+ val plannerCreator = createPlannerCreator(
+ sparkSession,
+ initHelper.conf,
+ initHelper.experimentalMethods)
+
+ val hiveSessionState = new HiveSessionState(
+ sparkContext,
+ sparkSession.sharedState,
+ initHelper.conf,
+ initHelper.experimentalMethods,
+ initHelper.functionRegistry,
+ catalog,
+ initHelper.sqlParser,
+ metadataHive,
+ analyzer,
+ initHelper.streamingQueryManager,
+ initHelper.queryExecutionCreator,
+ plannerCreator)
+ catalog.functionResourceLoader = hiveSessionState.functionResourceLoader
+ hiveSessionState
+ }
+
+ /**
+ * Create a logical query plan `Analyzer` with rules specific to a `HiveSessionState`.
+ */
+ private def createAnalyzer(
+ sparkSession: SparkSession,
+ catalog: HiveSessionCatalog,
+ sqlConf: SQLConf): Analyzer = {
+ new Analyzer(catalog, sqlConf) {
+ override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
+ new ResolveHiveSerdeTable(sparkSession) ::
+ new FindDataSourceTable(sparkSession) ::
+ new ResolveSQLOnFile(sparkSession) :: Nil
+
+ override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
+ new DetermineTableStats(sparkSession) ::
+ catalog.ParquetConversions ::
+ catalog.OrcConversions ::
+ PreprocessTableCreation(sparkSession) ::
+ PreprocessTableInsertion(sqlConf) ::
+ DataSourceAnalysis(sqlConf) ::
+ HiveAnalysis :: Nil
+
+ override val extendedCheckRules = Seq(PreWriteCheck)
+ }
+ }
+
+ private def createPlannerCreator(
+ associatedSparkSession: SparkSession,
+ sqlConf: SQLConf,
+ experimentalMethods: ExperimentalMethods): () => SparkPlanner = {
+ () =>
+ new SparkPlanner(
+ associatedSparkSession.sparkContext,
+ sqlConf,
+ experimentalMethods.extraStrategies)
+ with HiveStrategies {
+
+ override val sparkSession: SparkSession = associatedSparkSession
+
+ override def strategies: Seq[Strategy] = {
+ experimentalMethods.extraStrategies ++ Seq(
+ FileSourceStrategy,
+ DataSourceStrategy,
+ SpecialLimits,
+ InMemoryScans,
+ HiveTableScans,
+ Scripts,
+ Aggregation,
+ JoinSelection,
+ BasicOperators
+ )
+ }
+ }
+ }
}
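
Illustrative sketch (not part of the patch): plannerCreator defers planner construction behind a () => SparkPlanner thunk, so the overridden `def planner` builds a fresh planner on every call and picks up strategies registered after the session state was created. The pattern with hypothetical types:

class Planner(val strategies: Seq[String])

class ExperimentalStub {
  @volatile var extraStrategies: Seq[String] = Nil
}

object PlannerCreatorSketch {
  // Analogue of createPlannerCreator: capture the inputs, build the planner lazily per call.
  def create(experimental: ExperimentalStub, builtin: Seq[String]): () => Planner =
    () => new Planner(experimental.extraStrategies ++ builtin)

  def main(args: Array[String]): Unit = {
    val experimental = new ExperimentalStub
    val plannerCreator = create(experimental, Seq("FileSourceStrategy", "JoinSelection"))

    experimental.extraStrategies = Seq("MyStrategy")
    // The strategy added after the creator was built is still visible, because the
    // planner is only constructed when the thunk is invoked.
    assert(plannerCreator().strategies == Seq("MyStrategy", "FileSourceStrategy", "JoinSelection"))
  }
}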
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index 469c9d84de..6e1f429286 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -278,6 +278,8 @@ private[hive] class HiveClientImpl(
state.getConf.setClassLoader(clientLoader.classLoader)
// Set the thread local metastore client to the client associated with this HiveClientImpl.
Hive.set(client)
+ // Replace conf in the thread local Hive with current conf
+ Hive.get(conf)
// setCurrentSessionState will use the classLoader associated
// with the HiveConf in `state` to override the context class loader of the current
// thread.
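
The added Hive.get(conf) call refreshes the thread-local Hive object so that it carries the conf of the client currently in use, which matters once several client sessions (for example, sessions created for cloned SparkSessions) run on the same thread. Below is a rough, self-contained sketch of that thread-local refresh idea; Client and ClientHolder are invented for illustration and are not the Hive API.

final class Client(val conf: Map[String, String])

object ClientHolder {
  private val local = new ThreadLocal[Client]

  def set(c: Client): Unit = local.set(c)

  // Return the thread-local client, rebuilding it when the caller's conf differs
  // from the one the cached client was created with.
  def get(conf: Map[String, String]): Client = {
    val cached = local.get()
    if (cached == null || cached.conf != conf) {
      val fresh = new Client(conf)
      local.set(fresh)
      fresh
    } else {
      cached
    }
  }
}

object ClientHolderSketch extends App {
  val conf1 = Map("metastore.uri" -> "thrift://a:9083")
  val c1 = ClientHolder.get(conf1)
  assert(ClientHolder.get(conf1) eq c1)                   // same conf: cached client reused

  val conf2 = Map("metastore.uri" -> "thrift://b:9083")   // a client with a different conf
  val c2 = ClientHolder.get(conf2)
  assert(c2 ne c1)                                        // stale thread-local entry replaced
}
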
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index efc2f00984..076c40d459 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -30,16 +30,17 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{SparkSession, SQLContext}
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+import org.apache.spark.sql.{ExperimentalMethods, SparkSession, SQLContext}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner}
import org.apache.spark.sql.execution.command.CacheTableCommand
import org.apache.spark.sql.hive._
-import org.apache.spark.sql.internal.{SharedState, SQLConf}
+import org.apache.spark.sql.hive.client.HiveClient
+import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.util.{ShutdownHookManager, Utils}
// SPARK-3729: Test key required to check for initialization errors with config.
@@ -84,7 +85,7 @@ class TestHiveContext(
new TestHiveContext(sparkSession.newSession())
}
- override def sessionState: TestHiveSessionState = sparkSession.sessionState
+ override def sessionState: HiveSessionState = sparkSession.sessionState
def setCacheTables(c: Boolean): Unit = {
sparkSession.setCacheTables(c)
@@ -144,11 +145,35 @@ private[hive] class TestHiveSparkSession(
existingSharedState.getOrElse(new SharedState(sc))
}
- // TODO: Let's remove TestHiveSessionState. Otherwise, we are not really testing the reflection
- // logic based on the setting of CATALOG_IMPLEMENTATION.
@transient
- override lazy val sessionState: TestHiveSessionState =
- new TestHiveSessionState(self)
+ override lazy val sessionState: HiveSessionState = {
+ val testConf =
+ new SQLConf {
+ clear()
+ override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
+ override def clear(): Unit = {
+ super.clear()
+ TestHiveContext.overrideConfs.foreach { case (k, v) => setConfString(k, v) }
+ }
+ }
+ val queryExecutionCreator = (plan: LogicalPlan) => new TestHiveQueryExecution(this, plan)
+ val initHelper = HiveSessionState(this, testConf)
+ SessionState.mergeSparkConf(testConf, sparkContext.getConf)
+
+ new HiveSessionState(
+ sparkContext,
+ sharedState,
+ testConf,
+ initHelper.experimentalMethods,
+ initHelper.functionRegistry,
+ initHelper.catalog,
+ initHelper.sqlParser,
+ initHelper.metadataHive,
+ initHelper.analyzer,
+ initHelper.streamingQueryManager,
+ queryExecutionCreator,
+ initHelper.plannerCreator)
+ }
override def newSession(): TestHiveSparkSession = {
new TestHiveSparkSession(sc, Some(sharedState), loadTestTables)
@@ -492,26 +517,6 @@ private[hive] class TestHiveQueryExecution(
}
}
-private[hive] class TestHiveSessionState(
- sparkSession: TestHiveSparkSession)
- extends HiveSessionState(sparkSession) { self =>
-
- override lazy val conf: SQLConf = {
- new SQLConf {
- clear()
- override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
- override def clear(): Unit = {
- super.clear()
- TestHiveContext.overrideConfs.foreach { case (k, v) => setConfString(k, v) }
- }
- }
- }
-
- override def executePlan(plan: LogicalPlan): TestHiveQueryExecution = {
- new TestHiveQueryExecution(sparkSession, plan)
- }
-}
-
private[hive] object TestHiveContext {
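
The TestHive session state above builds a conf whose clear() always re-applies TestHiveContext.overrideConfs before delegating the rest of the construction to the HiveSessionState factory. A small, self-contained sketch of that "clear() re-applies suite defaults" pattern; Conf, TestConfSketch and the config key are stand-ins invented for the example, not Spark's SQLConf.

object TestConfSketch extends App {
  // Stand-in conf with just enough surface for the example.
  class Conf {
    private var settings = Map.empty[String, String]
    def setConfString(k: String, v: String): Unit = settings = settings + (k -> v)
    def getConfString(k: String): Option[String] = settings.get(k)
    def clear(): Unit = settings = Map.empty
  }

  // Hypothetical suite-wide defaults.
  val overrideConfs = Map("spark.sql.test.default" -> "value")

  // Anonymous subclass whose clear() re-applies the overrides, so resetting the
  // conf inside a test still leaves the suite-wide defaults in place.
  val testConf = new Conf {
    clear()
    override def clear(): Unit = {
      super.clear()
      overrideConfs.foreach { case (k, v) => setConfString(k, v) }
    }
  }

  testConf.setConfString("some.other.key", "x")
  testConf.clear()
  assert(testConf.getConfString("some.other.key").isEmpty)
  assert(testConf.getConfString("spark.sql.test.default") == Some("value"))
}
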
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala
new file mode 100644
index 0000000000..3b0f59b159
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry
+import org.apache.spark.sql.catalyst.catalog.CatalogDatabase
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.plans.logical.Range
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.Utils
+
+class HiveSessionCatalogSuite extends TestHiveSingleton {
+
+ test("clone HiveSessionCatalog") {
+ val original = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
+
+ val tempTableName1 = "copytest1"
+ val tempTableName2 = "copytest2"
+ try {
+ val tempTable1 = Range(1, 10, 1, 10)
+ original.createTempView(tempTableName1, tempTable1, overrideIfExists = false)
+
+ // check if tables copied over
+ val clone = original.newSessionCatalogWith(
+ spark,
+ new SQLConf,
+ new Configuration(),
+ new SimpleFunctionRegistry,
+ CatalystSqlParser)
+ assert(original ne clone)
+ assert(clone.getTempView(tempTableName1) == Some(tempTable1))
+
+      // check if clone and original are independent
+ clone.dropTable(TableIdentifier(tempTableName1), ignoreIfNotExists = false, purge = false)
+ assert(original.getTempView(tempTableName1) == Some(tempTable1))
+
+ val tempTable2 = Range(1, 20, 2, 10)
+ original.createTempView(tempTableName2, tempTable2, overrideIfExists = false)
+ assert(clone.getTempView(tempTableName2).isEmpty)
+ } finally {
+ // Drop the created temp views from the global singleton HiveSession.
+ original.dropTable(TableIdentifier(tempTableName1), ignoreIfNotExists = true, purge = true)
+ original.dropTable(TableIdentifier(tempTableName2), ignoreIfNotExists = true, purge = true)
+ }
+ }
+
+ test("clone SessionCatalog - current db") {
+ val original = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
+ val originalCurrentDatabase = original.getCurrentDatabase
+ val db1 = "db1"
+ val db2 = "db2"
+ val db3 = "db3"
+ try {
+ original.createDatabase(newDb(db1), ignoreIfExists = true)
+ original.createDatabase(newDb(db2), ignoreIfExists = true)
+ original.createDatabase(newDb(db3), ignoreIfExists = true)
+
+ original.setCurrentDatabase(db1)
+
+      // create a clone of the original catalog
+ val clone = original.newSessionCatalogWith(
+ spark,
+ new SQLConf,
+ new Configuration(),
+ new SimpleFunctionRegistry,
+ CatalystSqlParser)
+
+ // check if current db copied over
+ assert(original ne clone)
+ assert(clone.getCurrentDatabase == db1)
+
+      // check if clone and original are independent
+ clone.setCurrentDatabase(db2)
+ assert(original.getCurrentDatabase == db1)
+ original.setCurrentDatabase(db3)
+ assert(clone.getCurrentDatabase == db2)
+ } finally {
+ // Drop the created databases from the global singleton HiveSession.
+ original.dropDatabase(db1, ignoreIfNotExists = true, cascade = true)
+ original.dropDatabase(db2, ignoreIfNotExists = true, cascade = true)
+ original.dropDatabase(db3, ignoreIfNotExists = true, cascade = true)
+ original.setCurrentDatabase(originalCurrentDatabase)
+ }
+ }
+
+ def newUriForDatabase(): URI = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/"))
+
+ def newDb(name: String): CatalogDatabase = {
+ CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty)
+ }
+}
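
These tests rely on the clone receiving its own copy of the temp-view registry and of the current-database setting, so that drops and database switches on one side are invisible to the other. A minimal, self-contained illustration of that copy-on-clone contract; ToyCatalog below is an invented stand-in, not Spark's SessionCatalog.

import scala.collection.mutable

class ToyCatalog(initial: mutable.HashMap[String, String]) {
  def this() = this(mutable.HashMap.empty[String, String])
  private val tempViews = initial
  var currentDb: String = "default"

  def createTempView(name: String, plan: String): Unit = tempViews(name) = plan
  def getTempView(name: String): Option[String] = tempViews.get(name)
  def dropTempView(name: String): Unit = tempViews.remove(name)

  // Cloning copies the map (and the current database), so later changes on
  // either side stay local to that catalog.
  def newCatalogWith(): ToyCatalog = {
    val copy = new ToyCatalog(tempViews.clone())
    copy.currentDb = currentDb
    copy
  }
}

object CatalogCloneSketch extends App {
  val original = new ToyCatalog()
  original.createTempView("copytest1", "Range(1, 10)")

  val clone = original.newCatalogWith()
  assert(clone.getTempView("copytest1") == Some("Range(1, 10)"))

  clone.dropTempView("copytest1")                         // independent after cloning
  assert(original.getTempView("copytest1").isDefined)

  original.createTempView("copytest2", "Range(1, 20)")    // not visible to the clone
  assert(clone.getTempView("copytest2").isEmpty)

  clone.currentDb = "db2"                                 // current db is also independent
  assert(original.currentDb == "default")
}
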
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala
new file mode 100644
index 0000000000..67c77fb62f
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+
+/**
+ * Run all tests from `SessionStateSuite` with a `HiveSessionState`.
+ */
+class HiveSessionStateSuite extends SessionStateSuite
+ with TestHiveSingleton with BeforeAndAfterEach {
+
+ override def beforeAll(): Unit = {
+ // Reuse the singleton session
+ activeSession = spark
+ }
+
+ override def afterAll(): Unit = {
+ // Set activeSession to null to avoid stopping the singleton session
+ activeSession = null
+ super.afterAll()
+ }
+}
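
This suite re-runs every check defined in SessionStateSuite and only swaps the session fixture, so the same session-state checks run against both the in-memory and the Hive-backed implementation. A rough sketch of that shared-fixture pattern in plain Scala, without ScalaTest; Session, SessionChecks and the concrete classes are invented for illustration.

trait Session { def name: String }
class InMemorySession extends Session { val name = "in-memory" }
class HiveBackedSession extends Session { val name = "hive" }

// The base "suite" defines its checks once against an abstract fixture...
abstract class SessionChecks {
  protected var activeSession: Session = _
  def runAll(): Unit = {
    assert(activeSession != null, "session fixture not initialised")
    assert(activeSession.name.nonEmpty)
  }
}

// ...and each variant only swaps the fixture in, the way HiveSessionStateSuite
// reuses SessionStateSuite with the Hive singleton session.
class InMemoryChecks extends SessionChecks { activeSession = new InMemorySession }
class HiveChecks extends SessionChecks { activeSession = new HiveBackedSession }

object SharedSuiteSketch extends App {
  new InMemoryChecks().runAll()
  new HiveChecks().runAll()
}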