Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 38
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala | 55
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 59
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala | 235
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala | 162
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala | 21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala | 18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala | 20
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 5
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala | 92
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala | 261
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 67
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala | 112
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala | 41
20 files changed, 981 insertions, 236 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index fb99cb27b8..cff0efa979 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -66,6 +66,8 @@ trait CatalystConf {
/** The maximum number of joined nodes allowed in the dynamic programming algorithm. */
def joinReorderDPThreshold: Int
+
+ override def clone(): CatalystConf = throw new CloneNotSupportedException()
}
@@ -85,4 +87,7 @@ case class SimpleCatalystConf(
joinReorderDPThreshold: Int = 12,
warehousePath: String = "/user/hive/warehouse",
sessionLocalTimeZone: String = TimeZone.getDefault().getID)
- extends CatalystConf
+ extends CatalystConf {
+
+ override def clone(): SimpleCatalystConf = this.copy()
+}
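
The two hunks above establish a clone contract on the conf: the trait's default refuses cloning, and the concrete case class opts in via its generated copy(). A minimal standalone sketch of the same idiom, using hypothetical names rather than the Spark classes:

// Hypothetical illustration of the default-throws clone pattern added above.
trait Conf {
  def caseSensitive: Boolean
  override def clone(): Conf = throw new CloneNotSupportedException()
}

case class SimpleConf(caseSensitive: Boolean = true) extends Conf {
  // A case class gets an independent copy for free via copy().
  override def clone(): SimpleConf = this.copy()
}

object CloneConfDemo extends App {
  val original = SimpleConf()
  val cloned = original.clone()
  assert(cloned == original && !(cloned eq original))
}
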
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 556fa99017..0dcb44081f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -64,6 +64,8 @@ trait FunctionRegistry {
/** Clear all registered functions. */
def clear(): Unit
+ /** Create a copy of this registry with identical functions as this registry. */
+ override def clone(): FunctionRegistry = throw new CloneNotSupportedException()
}
class SimpleFunctionRegistry extends FunctionRegistry {
@@ -107,7 +109,7 @@ class SimpleFunctionRegistry extends FunctionRegistry {
functionBuilders.clear()
}
- def copy(): SimpleFunctionRegistry = synchronized {
+ override def clone(): SimpleFunctionRegistry = synchronized {
val registry = new SimpleFunctionRegistry
functionBuilders.iterator.foreach { case (name, (info, builder)) =>
registry.registerFunction(name, info, builder)
@@ -150,6 +152,7 @@ object EmptyFunctionRegistry extends FunctionRegistry {
throw new UnsupportedOperationException
}
+ override def clone(): FunctionRegistry = this
}
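
The rename from copy() to clone() keeps the same semantics: duplicating a registry yields an independent one, so per-session registrations do not leak across sessions. A hedged sketch of that behaviour against the builtin registry (assuming the registerFunction/lookupFunction/clone signatures shown above; the "my_len" function name is made up for illustration):

import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.Length

object RegistryCloneSketch extends App {
  // Start from an independent copy of the builtin registry, as SessionState.apply
  // does with FunctionRegistry.builtin.clone().
  val original = FunctionRegistry.builtin.clone()
  val cloned = original.clone()

  // A registration in one registry is invisible in its clone.
  original.registerFunction("my_len", exprs => Length(exprs.head))
  assert(original.lookupFunction("my_len").nonEmpty)
  assert(cloned.lookupFunction("my_len").isEmpty)
}
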
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 831e37aac1..6cfc4a4321 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -50,7 +50,6 @@ object SessionCatalog {
class SessionCatalog(
externalCatalog: ExternalCatalog,
globalTempViewManager: GlobalTempViewManager,
- functionResourceLoader: FunctionResourceLoader,
functionRegistry: FunctionRegistry,
conf: CatalystConf,
hadoopConf: Configuration,
@@ -66,16 +65,19 @@ class SessionCatalog(
this(
externalCatalog,
new GlobalTempViewManager("global_temp"),
- DummyFunctionResourceLoader,
functionRegistry,
conf,
new Configuration(),
CatalystSqlParser)
+ functionResourceLoader = DummyFunctionResourceLoader
}
// For testing only.
def this(externalCatalog: ExternalCatalog) {
- this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true))
+ this(
+ externalCatalog,
+ new SimpleFunctionRegistry,
+ SimpleCatalystConf(caseSensitiveAnalysis = true))
}
/** List of temporary tables, mapping from table name to their logical plan. */
@@ -89,6 +91,8 @@ class SessionCatalog(
@GuardedBy("this")
protected var currentDb = formatDatabaseName(DEFAULT_DATABASE)
+ @volatile var functionResourceLoader: FunctionResourceLoader = _
+
/**
* Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"),
* i.e. if this name only contains characters, numbers, and _.
@@ -987,6 +991,9 @@ class SessionCatalog(
* by a tuple (resource type, resource uri).
*/
def loadFunctionResources(resources: Seq[FunctionResource]): Unit = {
+ if (functionResourceLoader == null) {
+ throw new IllegalStateException("functionResourceLoader has not yet been initialized")
+ }
resources.foreach(functionResourceLoader.loadResource)
}
@@ -1182,4 +1189,29 @@ class SessionCatalog(
}
}
+ /**
+ * Create a new [[SessionCatalog]] with the provided parameters. `externalCatalog` and
+ * `globalTempViewManager` are `inherited`, while `currentDb` and `tempTables` are copied.
+ */
+ def newSessionCatalogWith(
+ conf: CatalystConf,
+ hadoopConf: Configuration,
+ functionRegistry: FunctionRegistry,
+ parser: ParserInterface): SessionCatalog = {
+ val catalog = new SessionCatalog(
+ externalCatalog,
+ globalTempViewManager,
+ functionRegistry,
+ conf,
+ hadoopConf,
+ parser)
+
+ synchronized {
+ catalog.currentDb = currentDb
+ // copy over temporary tables
+ tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2))
+ }
+
+ catalog
+ }
}
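
One consequence of dropping functionResourceLoader from the constructor is that anyone wiring a SessionCatalog by hand must assign the loader after construction, exactly as SessionState.apply does further down. A hedged sketch against existing catalyst classes (InMemoryCatalog and DummyFunctionResourceLoader already live in the catalyst.catalog package; this is illustration, not the PR's code):

import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry
import org.apache.spark.sql.catalyst.catalog.{DummyFunctionResourceLoader, GlobalTempViewManager, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

object CatalogWiringSketch extends App {
  // Mirrors the wiring in SessionState.apply: the catalog is built without a
  // resource loader ...
  val catalog = new SessionCatalog(
    new InMemoryCatalog,
    new GlobalTempViewManager("global_temp"),
    new SimpleFunctionRegistry,
    SimpleCatalystConf(caseSensitiveAnalysis = true),
    new Configuration(),
    CatalystSqlParser)

  // ... and the loader is assigned afterwards. loadFunctionResources throws an
  // IllegalStateException if this assignment is skipped.
  catalog.functionResourceLoader = DummyFunctionResourceLoader
}
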
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 328a16c4bf..7e74dcdef0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.catalog
+import org.apache.hadoop.conf.Configuration
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, SimpleCatalystConf, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
@@ -1197,6 +1199,59 @@ class SessionCatalogSuite extends PlanTest {
}
}
+ test("clone SessionCatalog - temp views") {
+ val externalCatalog = newEmptyCatalog()
+ val original = new SessionCatalog(externalCatalog)
+ val tempTable1 = Range(1, 10, 1, 10)
+ original.createTempView("copytest1", tempTable1, overrideIfExists = false)
+
+ // check if tables copied over
+ val clone = original.newSessionCatalogWith(
+ SimpleCatalystConf(caseSensitiveAnalysis = true),
+ new Configuration(),
+ new SimpleFunctionRegistry,
+ CatalystSqlParser)
+ assert(original ne clone)
+ assert(clone.getTempView("copytest1") == Some(tempTable1))
+
+ // check if clone and original independent
+ clone.dropTable(TableIdentifier("copytest1"), ignoreIfNotExists = false, purge = false)
+ assert(original.getTempView("copytest1") == Some(tempTable1))
+
+ val tempTable2 = Range(1, 20, 2, 10)
+ original.createTempView("copytest2", tempTable2, overrideIfExists = false)
+ assert(clone.getTempView("copytest2").isEmpty)
+ }
+
+ test("clone SessionCatalog - current db") {
+ val externalCatalog = newEmptyCatalog()
+ val db1 = "db1"
+ val db2 = "db2"
+ val db3 = "db3"
+
+ externalCatalog.createDatabase(newDb(db1), ignoreIfExists = true)
+ externalCatalog.createDatabase(newDb(db2), ignoreIfExists = true)
+ externalCatalog.createDatabase(newDb(db3), ignoreIfExists = true)
+
+ val original = new SessionCatalog(externalCatalog)
+ original.setCurrentDatabase(db1)
+
+ // check if current db copied over
+ val clone = original.newSessionCatalogWith(
+ SimpleCatalystConf(caseSensitiveAnalysis = true),
+ new Configuration(),
+ new SimpleFunctionRegistry,
+ CatalystSqlParser)
+ assert(original ne clone)
+ assert(clone.getCurrentDatabase == db1)
+
+ // check if clone and original independent
+ clone.setCurrentDatabase(db2)
+ assert(original.getCurrentDatabase == db1)
+ original.setCurrentDatabase(db3)
+ assert(clone.getCurrentDatabase == db2)
+ }
+
test("SPARK-19737: detect undefined functions without triggering relation resolution") {
import org.apache.spark.sql.catalyst.dsl.plans._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
index 1e8ba51e59..bd8dd6ea3f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
@@ -46,4 +46,10 @@ class ExperimentalMethods private[sql]() {
@volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
+ override def clone(): ExperimentalMethods = {
+ val result = new ExperimentalMethods
+ result.extraStrategies = extraStrategies
+ result.extraOptimizations = extraOptimizations
+ result
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index afc1827e7e..49562578b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -21,7 +21,6 @@ import java.io.Closeable
import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConverters._
-import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -43,7 +42,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
-import org.apache.spark.sql.types.{DataType, LongType, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ExecutionListenerManager
import org.apache.spark.util.Utils
@@ -67,15 +66,22 @@ import org.apache.spark.util.Utils
* .config("spark.some.config.option", "some-value")
* .getOrCreate()
* }}}
+ *
+ * @param sparkContext The Spark context associated with this Spark session.
+ * @param existingSharedState If supplied, use the existing shared state
+ * instead of creating a new one.
+ * @param parentSessionState If supplied, inherit all session state (i.e. temporary
+ * views, SQL config, UDFs etc) from parent.
*/
@InterfaceStability.Stable
class SparkSession private(
@transient val sparkContext: SparkContext,
- @transient private val existingSharedState: Option[SharedState])
+ @transient private val existingSharedState: Option[SharedState],
+ @transient private val parentSessionState: Option[SessionState])
extends Serializable with Closeable with Logging { self =>
private[sql] def this(sc: SparkContext) {
- this(sc, None)
+ this(sc, None, None)
}
sparkContext.assertNotStopped()
@@ -108,6 +114,7 @@ class SparkSession private(
/**
* State isolated across sessions, including SQL configurations, temporary tables, registered
* functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]].
+ * If `parentSessionState` is not null, the `SessionState` will be a copy of the parent.
*
* This is internal to Spark and there is no guarantee on interface stability.
*
@@ -116,9 +123,13 @@ class SparkSession private(
@InterfaceStability.Unstable
@transient
lazy val sessionState: SessionState = {
- SparkSession.reflect[SessionState, SparkSession](
- SparkSession.sessionStateClassName(sparkContext.conf),
- self)
+ parentSessionState
+ .map(_.clone(this))
+ .getOrElse {
+ SparkSession.instantiateSessionState(
+ SparkSession.sessionStateClassName(sparkContext.conf),
+ self)
+ }
}
/**
@@ -208,7 +219,25 @@ class SparkSession private(
* @since 2.0.0
*/
def newSession(): SparkSession = {
- new SparkSession(sparkContext, Some(sharedState))
+ new SparkSession(sparkContext, Some(sharedState), parentSessionState = None)
+ }
+
+ /**
+ * Create an identical copy of this `SparkSession`, sharing the underlying `SparkContext`
+ * and shared state. All the state of this session (i.e. SQL configurations, temporary tables,
+ * registered functions) is copied over, and the cloned session is set up with the same shared
+ * state as this session. The cloned session is independent of this session, that is, any
+ * non-global change in either session is not reflected in the other.
+ *
+ * @note Other than the `SparkContext`, all shared state is initialized lazily.
+ * This method will force the initialization of the shared state to ensure that parent
+ * and child sessions are set up with the same shared state. If the underlying catalog
+ * implementation is Hive, this will initialize the metastore, which may take some time.
+ */
+ private[sql] def cloneSession(): SparkSession = {
+ val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState))
+ result.sessionState // force copy of SessionState
+ result
}
@@ -971,16 +1000,18 @@ object SparkSession {
}
/**
- * Helper method to create an instance of [[T]] using a single-arg constructor that
- * accepts an [[Arg]].
+ * Helper method to create an instance of `SessionState` based on `className` from conf.
+ * The result is either `SessionState` or `HiveSessionState`.
*/
- private def reflect[T, Arg <: AnyRef](
+ private def instantiateSessionState(
className: String,
- ctorArg: Arg)(implicit ctorArgTag: ClassTag[Arg]): T = {
+ sparkSession: SparkSession): SessionState = {
+
try {
+ // get `SessionState.apply(SparkSession)`
val clazz = Utils.classForName(className)
- val ctor = clazz.getDeclaredConstructor(ctorArgTag.runtimeClass)
- ctor.newInstance(ctorArg).asInstanceOf[T]
+ val method = clazz.getMethod("apply", sparkSession.getClass)
+ method.invoke(null, sparkSession).asInstanceOf[SessionState]
} catch {
case NonFatal(e) =>
throw new IllegalArgumentException(s"Error while instantiating '$className':", e)
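
instantiateSessionState swaps constructor reflection for a lookup of the companion object's apply(SparkSession). That works because the Scala compiler emits a static forwarder for the companion's apply on the class itself, so getMethod plus invoke(null, ...) reaches the object. A self-contained sketch of the same lookup with hypothetical classes:

// Hypothetical classes; only the reflection pattern mirrors the change above.
class Widget(val name: String)

object Widget {
  def apply(name: String): Widget = new Widget(name)
}

object ReflectiveApplySketch extends App {
  val clazz = Class.forName("Widget") // in Spark: Utils.classForName(className)
  // The companion's apply is reachable as a static forwarder on the class,
  // so the receiver passed to invoke is null.
  val method = clazz.getMethod("apply", classOf[String])
  val widget = method.invoke(null, "w1").asInstanceOf[Widget]
  assert(widget.name == "w1")
}
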
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 4d781b96ab..8b598cc60e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -66,7 +66,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
* Preprocess [[CreateTable]], to do some normalization and checking.
*/
case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[LogicalPlan] {
- private val catalog = sparkSession.sessionState.catalog
+ // catalog is a def and not a val/lazy val as the latter would introduce a circular reference
+ private def catalog = sparkSession.sessionState.catalog
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// When we CREATE TABLE without specifying the table schema, we should fail the query if
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 94e3fa7dd1..1244f690fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1019,6 +1019,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def clear(): Unit = {
settings.clear()
}
+
+ override def clone(): SQLConf = {
+ val result = new SQLConf
+ getAllConfs.foreach {
+ case(k, v) => if (v ne null) result.setConfString(k, v)
+ }
+ result
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 6908560511..ce80604bd3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -22,38 +22,49 @@ import java.io.File
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.command.AnalyzeTableCommand
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager}
+import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.util.ExecutionListenerManager
/**
* A class that holds all session-specific state in a given [[SparkSession]].
+ * @param sparkContext The [[SparkContext]].
+ * @param sharedState The shared state.
+ * @param conf SQL-specific key-value configurations.
+ * @param experimentalMethods The experimental methods.
+ * @param functionRegistry Internal catalog for managing functions registered by the user.
+ * @param catalog Internal catalog for managing table and database states.
+ * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
+ * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations.
+ * @param streamingQueryManager Interface to start and stop
+ * [[org.apache.spark.sql.streaming.StreamingQuery]]s.
+ * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]]
*/
-private[sql] class SessionState(sparkSession: SparkSession) {
+private[sql] class SessionState(
+ sparkContext: SparkContext,
+ sharedState: SharedState,
+ val conf: SQLConf,
+ val experimentalMethods: ExperimentalMethods,
+ val functionRegistry: FunctionRegistry,
+ val catalog: SessionCatalog,
+ val sqlParser: ParserInterface,
+ val analyzer: Analyzer,
+ val streamingQueryManager: StreamingQueryManager,
+ val queryExecutionCreator: LogicalPlan => QueryExecution) {
- // Note: These are all lazy vals because they depend on each other (e.g. conf) and we
- // want subclasses to override some of the fields. Otherwise, we would get a lot of NPEs.
-
- /**
- * SQL-specific key-value configurations.
- */
- lazy val conf: SQLConf = new SQLConf
-
- def newHadoopConf(): Configuration = {
- val hadoopConf = new Configuration(sparkSession.sparkContext.hadoopConfiguration)
- conf.getAllConfs.foreach { case (k, v) => if (v ne null) hadoopConf.set(k, v) }
- hadoopConf
- }
+ def newHadoopConf(): Configuration = SessionState.newHadoopConf(
+ sparkContext.hadoopConfiguration,
+ conf)
def newHadoopConfWithOptions(options: Map[String, String]): Configuration = {
val hadoopConf = newHadoopConf()
@@ -65,22 +76,15 @@ private[sql] class SessionState(sparkSession: SparkSession) {
hadoopConf
}
- lazy val experimentalMethods = new ExperimentalMethods
-
- /**
- * Internal catalog for managing functions registered by the user.
- */
- lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()
-
/**
* A class for loading resources specified by a function.
*/
- lazy val functionResourceLoader: FunctionResourceLoader = {
+ val functionResourceLoader: FunctionResourceLoader = {
new FunctionResourceLoader {
override def loadResource(resource: FunctionResource): Unit = {
resource.resourceType match {
case JarResource => addJar(resource.uri)
- case FileResource => sparkSession.sparkContext.addFile(resource.uri)
+ case FileResource => sparkContext.addFile(resource.uri)
case ArchiveResource =>
throw new AnalysisException(
"Archive is not allowed to be loaded. If YARN mode is used, " +
@@ -91,92 +95,77 @@ private[sql] class SessionState(sparkSession: SparkSession) {
}
/**
- * Internal catalog for managing table and database states.
- */
- lazy val catalog = new SessionCatalog(
- sparkSession.sharedState.externalCatalog,
- sparkSession.sharedState.globalTempViewManager,
- functionResourceLoader,
- functionRegistry,
- conf,
- newHadoopConf(),
- sqlParser)
-
- /**
* Interface exposed to the user for registering user-defined functions.
* Note that the user-defined functions must be deterministic.
*/
- lazy val udf: UDFRegistration = new UDFRegistration(functionRegistry)
-
- /**
- * Logical query plan analyzer for resolving unresolved attributes and relations.
- */
- lazy val analyzer: Analyzer = {
- new Analyzer(catalog, conf) {
- override val extendedResolutionRules =
- new FindDataSourceTable(sparkSession) ::
- new ResolveSQLOnFile(sparkSession) :: Nil
-
- override val postHocResolutionRules =
- PreprocessTableCreation(sparkSession) ::
- PreprocessTableInsertion(conf) ::
- DataSourceAnalysis(conf) :: Nil
-
- override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck)
- }
- }
+ val udf: UDFRegistration = new UDFRegistration(functionRegistry)
/**
* Logical query plan optimizer.
*/
- lazy val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods)
-
- /**
- * Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
- */
- lazy val sqlParser: ParserInterface = new SparkSqlParser(conf)
+ val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods)
/**
* Planner that converts optimized logical plans to physical plans.
*/
def planner: SparkPlanner =
- new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies)
+ new SparkPlanner(sparkContext, conf, experimentalMethods.extraStrategies)
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
* that listen for execution metrics.
*/
- lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
+ val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
/**
- * Interface to start and stop [[StreamingQuery]]s.
+ * Get an identical copy of the `SessionState` and associate it with the given `SparkSession`
*/
- lazy val streamingQueryManager: StreamingQueryManager = {
- new StreamingQueryManager(sparkSession)
- }
+ def clone(newSparkSession: SparkSession): SessionState = {
+ val sparkContext = newSparkSession.sparkContext
+ val confCopy = conf.clone()
+ val functionRegistryCopy = functionRegistry.clone()
+ val sqlParser: ParserInterface = new SparkSqlParser(confCopy)
+ val catalogCopy = catalog.newSessionCatalogWith(
+ confCopy,
+ SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy),
+ functionRegistryCopy,
+ sqlParser)
+ val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan)
- private val jarClassLoader: NonClosableMutableURLClassLoader =
- sparkSession.sharedState.jarClassLoader
+ SessionState.mergeSparkConf(confCopy, sparkContext.getConf)
- // Automatically extract all entries and put it in our SQLConf
- // We need to call it after all of vals have been initialized.
- sparkSession.sparkContext.getConf.getAll.foreach { case (k, v) =>
- conf.setConfString(k, v)
+ new SessionState(
+ sparkContext,
+ newSparkSession.sharedState,
+ confCopy,
+ experimentalMethods.clone(),
+ functionRegistryCopy,
+ catalogCopy,
+ sqlParser,
+ SessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy),
+ new StreamingQueryManager(newSparkSession),
+ queryExecutionCreator)
}
// ------------------------------------------------------
// Helper methods, partially leftover from pre-2.0 days
// ------------------------------------------------------
- def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan)
+ def executePlan(plan: LogicalPlan): QueryExecution = queryExecutionCreator(plan)
def refreshTable(tableName: String): Unit = {
catalog.refreshTable(sqlParser.parseTableIdentifier(tableName))
}
+ /**
+ * Add a jar path to [[SparkContext]] and the classloader.
+ *
+ * Note: this method does not seem to access any session state, but the subclass `HiveSessionState` needs
+ * to add the jar to its hive client for the current session. Hence, it still needs to be in
+ * [[SessionState]].
+ */
def addJar(path: String): Unit = {
- sparkSession.sparkContext.addJar(path)
-
+ sparkContext.addJar(path)
val uri = new Path(path).toUri
val jarURL = if (uri.getScheme == null) {
// `path` is a local file path without a URL scheme
@@ -185,15 +174,93 @@ private[sql] class SessionState(sparkSession: SparkSession) {
// `path` is a URL with a scheme
uri.toURL
}
- jarClassLoader.addURL(jarURL)
- Thread.currentThread().setContextClassLoader(jarClassLoader)
+ sharedState.jarClassLoader.addURL(jarURL)
+ Thread.currentThread().setContextClassLoader(sharedState.jarClassLoader)
+ }
+}
+
+
+private[sql] object SessionState {
+
+ def apply(sparkSession: SparkSession): SessionState = {
+ apply(sparkSession, new SQLConf)
+ }
+
+ def apply(sparkSession: SparkSession, sqlConf: SQLConf): SessionState = {
+ val sparkContext = sparkSession.sparkContext
+
+ // Automatically extract all entries and put them in our SQLConf
+ mergeSparkConf(sqlConf, sparkContext.getConf)
+
+ val functionRegistry = FunctionRegistry.builtin.clone()
+
+ val sqlParser: ParserInterface = new SparkSqlParser(sqlConf)
+
+ val catalog = new SessionCatalog(
+ sparkSession.sharedState.externalCatalog,
+ sparkSession.sharedState.globalTempViewManager,
+ functionRegistry,
+ sqlConf,
+ newHadoopConf(sparkContext.hadoopConfiguration, sqlConf),
+ sqlParser)
+
+ val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, sqlConf)
+
+ val streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(sparkSession)
+
+ val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(sparkSession, plan)
+
+ val sessionState = new SessionState(
+ sparkContext,
+ sparkSession.sharedState,
+ sqlConf,
+ new ExperimentalMethods,
+ functionRegistry,
+ catalog,
+ sqlParser,
+ analyzer,
+ streamingQueryManager,
+ queryExecutionCreator)
+ // functionResourceLoader needs to access SessionState.addJar, so it cannot be created before
+ // creating SessionState. Setting `catalog.functionResourceLoader` here is safe since the caller
+ // cannot use SessionCatalog before we return SessionState.
+ catalog.functionResourceLoader = sessionState.functionResourceLoader
+ sessionState
+ }
+
+ def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = {
+ val newHadoopConf = new Configuration(hadoopConf)
+ sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) }
+ newHadoopConf
+ }
+
+ /**
+ * Create a logical query plan `Analyzer` with rules specific to a non-Hive `SessionState`.
+ */
+ private def createAnalyzer(
+ sparkSession: SparkSession,
+ catalog: SessionCatalog,
+ sqlConf: SQLConf): Analyzer = {
+ new Analyzer(catalog, sqlConf) {
+ override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
+ new FindDataSourceTable(sparkSession) ::
+ new ResolveSQLOnFile(sparkSession) :: Nil
+
+ override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
+ PreprocessTableCreation(sparkSession) ::
+ PreprocessTableInsertion(sqlConf) ::
+ DataSourceAnalysis(sqlConf) :: Nil
+
+ override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck)
+ }
}
/**
- * Analyzes the given table in the current database to generate statistics, which will be
- * used in query optimizations.
+ * Extract entries from `SparkConf` and put them in the `SQLConf`
*/
- def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = {
- AnalyzeTableCommand(tableIdent, noscan).run(sparkSession)
+ def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = {
+ sparkConf.getAll.foreach { case (k, v) =>
+ sqlConf.setConfString(k, v)
+ }
}
}
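
mergeSparkConf is a straight copy of every SparkConf entry into the session's SQLConf, SQL and non-SQL keys alike, replacing the constructor-time loop the class used to run. A hedged sketch of the equivalent behaviour; because SQLConf is private[sql], the sketch assumes it is compiled under an org.apache.spark.sql subpackage:

package org.apache.spark.sql.sketches // needed only because SQLConf is private[sql]

import org.apache.spark.SparkConf
import org.apache.spark.sql.internal.SQLConf

object MergeSparkConfSketch extends App {
  val sparkConf = new SparkConf(loadDefaults = false)
    .set("spark.app.name", "merge-sketch")
    .set("spark.sql.shuffle.partitions", "4")

  val sqlConf = new SQLConf
  // Same effect as SessionState.mergeSparkConf(sqlConf, sparkConf).
  sparkConf.getAll.foreach { case (k, v) => sqlConf.setConfString(k, v) }

  assert(sqlConf.getConfString("spark.sql.shuffle.partitions") == "4")
}
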
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
new file mode 100644
index 0000000000..2d5e37242a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+class SessionStateSuite extends SparkFunSuite
+ with BeforeAndAfterEach with BeforeAndAfterAll {
+
+ /**
+ * A shared SparkSession for all tests in this suite. Make sure you reset any changes to this
+ * session as this is a singleton HiveSparkSession in HiveSessionStateSuite and it's shared
+ * with all Hive test suites.
+ */
+ protected var activeSession: SparkSession = _
+
+ override def beforeAll(): Unit = {
+ activeSession = SparkSession.builder().master("local").getOrCreate()
+ }
+
+ override def afterAll(): Unit = {
+ if (activeSession != null) {
+ activeSession.stop()
+ activeSession = null
+ }
+ super.afterAll()
+ }
+
+ test("fork new session and inherit RuntimeConfig options") {
+ val key = "spark-config-clone"
+ try {
+ activeSession.conf.set(key, "active")
+
+ // inheritance
+ val forkedSession = activeSession.cloneSession()
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.conf ne activeSession.conf)
+ assert(forkedSession.conf.get(key) == "active")
+
+ // independence
+ forkedSession.conf.set(key, "forked")
+ assert(activeSession.conf.get(key) == "active")
+ activeSession.conf.set(key, "dontcopyme")
+ assert(forkedSession.conf.get(key) == "forked")
+ } finally {
+ activeSession.conf.unset(key)
+ }
+ }
+
+ test("fork new session and inherit function registry and udf") {
+ val testFuncName1 = "strlenScala"
+ val testFuncName2 = "addone"
+ try {
+ activeSession.udf.register(testFuncName1, (_: String).length + (_: Int))
+ val forkedSession = activeSession.cloneSession()
+
+ // inheritance
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.sessionState.functionRegistry ne
+ activeSession.sessionState.functionRegistry)
+ assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty)
+
+ // independence
+ forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1)
+ assert(activeSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty)
+ activeSession.udf.register(testFuncName2, (_: Int) + 1)
+ assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName2).isEmpty)
+ } finally {
+ activeSession.sessionState.functionRegistry.dropFunction(testFuncName1)
+ activeSession.sessionState.functionRegistry.dropFunction(testFuncName2)
+ }
+ }
+
+ test("fork new session and inherit experimental methods") {
+ val originalExtraOptimizations = activeSession.experimental.extraOptimizations
+ val originalExtraStrategies = activeSession.experimental.extraStrategies
+ try {
+ object DummyRule1 extends Rule[LogicalPlan] {
+ def apply(p: LogicalPlan): LogicalPlan = p
+ }
+ object DummyRule2 extends Rule[LogicalPlan] {
+ def apply(p: LogicalPlan): LogicalPlan = p
+ }
+ val optimizations = List(DummyRule1, DummyRule2)
+ activeSession.experimental.extraOptimizations = optimizations
+ val forkedSession = activeSession.cloneSession()
+
+ // inheritance
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.experimental ne activeSession.experimental)
+ assert(forkedSession.experimental.extraOptimizations.toSet ==
+ activeSession.experimental.extraOptimizations.toSet)
+
+ // independence
+ forkedSession.experimental.extraOptimizations = List(DummyRule2)
+ assert(activeSession.experimental.extraOptimizations == optimizations)
+ activeSession.experimental.extraOptimizations = List(DummyRule1)
+ assert(forkedSession.experimental.extraOptimizations == List(DummyRule2))
+ } finally {
+ activeSession.experimental.extraOptimizations = originalExtraOptimizations
+ activeSession.experimental.extraStrategies = originalExtraStrategies
+ }
+ }
+
+ test("fork new sessions and run query on inherited table") {
+ def checkTableExists(sparkSession: SparkSession): Unit = {
+ QueryTest.checkAnswer(sparkSession.sql(
+ """
+ |SELECT x.str, COUNT(*)
+ |FROM df x JOIN df y ON x.str = y.str
+ |GROUP BY x.str
+ """.stripMargin),
+ Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
+ }
+
+ val spark = activeSession
+ // Cannot use `import activeSession.implicits._` due to the compiler limitation.
+ import spark.implicits._
+
+ try {
+ activeSession
+ .createDataset[(Int, String)](Seq(1, 2, 3).map(i => (i, i.toString)))
+ .toDF("int", "str")
+ .createOrReplaceTempView("df")
+ checkTableExists(activeSession)
+
+ val forkedSession = activeSession.cloneSession()
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.sessionState ne activeSession.sessionState)
+ checkTableExists(forkedSession)
+ checkTableExists(activeSession.cloneSession()) // ability to clone multiple times
+ checkTableExists(forkedSession.cloneSession()) // clone of clone
+ } finally {
+ activeSession.sql("drop table df")
+ }
+ }
+
+ test("fork new session and inherit reference to SharedState") {
+ val forkedSession = activeSession.cloneSession()
+ assert(activeSession.sharedState eq forkedSession.sharedState)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 989a7f2698..fcb8ffbc6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -493,6 +493,25 @@ class CatalogSuite
}
}
- // TODO: add tests for the rest of them
+ test("clone Catalog") {
+ // need to test tempTables are cloned
+ assert(spark.catalog.listTables().collect().isEmpty)
+ createTempTable("my_temp_table")
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+
+ // inheritance
+ val forkedSession = spark.cloneSession()
+ assert(spark ne forkedSession)
+ assert(spark.catalog ne forkedSession.catalog)
+ assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+
+ // independence
+ dropTable("my_temp_table") // drop table in original session
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set())
+ assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+ forkedSession.sessionState.catalog
+ .createTempView("fork_table", Range(1, 2, 3, 4), overrideIfExists = true)
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set())
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
index 0e3a5ca9d7..f2456c7704 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
@@ -187,4 +187,22 @@ class SQLConfEntrySuite extends SparkFunSuite {
}
assert(e2.getMessage === "The maximum size of the cache must not be negative")
}
+
+ test("clone SQLConf") {
+ val original = new SQLConf
+ val key = "spark.sql.SQLConfEntrySuite.clone"
+ assert(original.getConfString(key, "noentry") === "noentry")
+
+ // inheritance
+ original.setConfString(key, "orig")
+ val clone = original.clone()
+ assert(original ne clone)
+ assert(clone.getConfString(key, "noentry") === "orig")
+
+ // independence
+ clone.setConfString(key, "clone")
+ assert(original.getConfString(key, "noentry") === "orig")
+ original.setConfString(key, "dontcopyme")
+ assert(clone.getConfString(key, "noentry") === "clone")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 8ab6db175d..898a2fb4f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -35,18 +35,16 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) {
}
@transient
- override lazy val sessionState: SessionState = new SessionState(self) {
- override lazy val conf: SQLConf = {
- new SQLConf {
- clear()
- override def clear(): Unit = {
- super.clear()
- // Make sure we start with the default test configs even after clear
- TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) }
- }
+ override lazy val sessionState: SessionState = SessionState(
+ this,
+ new SQLConf {
+ clear()
+ override def clear(): Unit = {
+ super.clear()
+ // Make sure we start with the default test configs even after clear
+ TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) }
}
- }
- }
+ })
// Needed for Java tests
def loadTestData(): Unit = {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 4d3b6c3cec..d135dfa9f4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -41,8 +41,9 @@ import org.apache.spark.sql.types._
* cleaned up to integrate more nicely with [[HiveExternalCatalog]].
*/
private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging {
- private val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState]
- private lazy val tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache
+ // these are def_s and not val/lazy val since the latter would introduce circular references
+ private def sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState]
+ private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache
private def getCurrentDatabase: String = sessionState.catalog.getCurrentDatabase
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index f1ea86890c..6b7599e3d3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry}
import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF}
import org.apache.spark.sql.{AnalysisException, SparkSession}
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{CatalystConf, FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog}
@@ -43,31 +43,23 @@ import org.apache.spark.util.Utils
private[sql] class HiveSessionCatalog(
externalCatalog: HiveExternalCatalog,
globalTempViewManager: GlobalTempViewManager,
- sparkSession: SparkSession,
- functionResourceLoader: FunctionResourceLoader,
+ private val metastoreCatalog: HiveMetastoreCatalog,
functionRegistry: FunctionRegistry,
conf: SQLConf,
hadoopConf: Configuration,
parser: ParserInterface)
extends SessionCatalog(
- externalCatalog,
- globalTempViewManager,
- functionResourceLoader,
- functionRegistry,
- conf,
- hadoopConf,
- parser) {
+ externalCatalog,
+ globalTempViewManager,
+ functionRegistry,
+ conf,
+ hadoopConf,
+ parser) {
// ----------------------------------------------------------------
// | Methods and fields for interacting with HiveMetastoreCatalog |
// ----------------------------------------------------------------
- // Catalog for handling data source tables. TODO: This really doesn't belong here since it is
- // essentially a cache for metastore tables. However, it relies on a lot of session-specific
- // things so it would be a lot of work to split its functionality between HiveSessionCatalog
- // and HiveCatalog. We should still do it at some point...
- private val metastoreCatalog = new HiveMetastoreCatalog(sparkSession)
-
// These 2 rules must be run before all other DDL post-hoc resolution rules, i.e.
// `PreprocessTableCreation`, `PreprocessTableInsertion`, `DataSourceAnalysis` and `HiveAnalysis`.
val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions
@@ -77,10 +69,51 @@ private[sql] class HiveSessionCatalog(
metastoreCatalog.hiveDefaultTableFilePath(name)
}
+ /**
+ * Create a new [[HiveSessionCatalog]] with the provided parameters. `externalCatalog` and
+ * `globalTempViewManager` are `inherited`, while `currentDb` and `tempTables` are copied.
+ */
+ def newSessionCatalogWith(
+ newSparkSession: SparkSession,
+ conf: SQLConf,
+ hadoopConf: Configuration,
+ functionRegistry: FunctionRegistry,
+ parser: ParserInterface): HiveSessionCatalog = {
+ val catalog = HiveSessionCatalog(
+ newSparkSession,
+ functionRegistry,
+ conf,
+ hadoopConf,
+ parser)
+
+ synchronized {
+ catalog.currentDb = currentDb
+ // copy over temporary tables
+ tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2))
+ }
+
+ catalog
+ }
+
+ /**
+ * The parent class [[SessionCatalog]] cannot access the [[SparkSession]] class, so we cannot add
+ * a [[SparkSession]] parameter to [[SessionCatalog.newSessionCatalogWith]]. However,
+ * [[HiveSessionCatalog]] requires a [[SparkSession]] parameter, so we add a new version of
+ * `newSessionCatalogWith` and disable this one.
+ *
+ * TODO Refactor HiveSessionCatalog to not use [[SparkSession]] directly.
+ */
+ override def newSessionCatalogWith(
+ conf: CatalystConf,
+ hadoopConf: Configuration,
+ functionRegistry: FunctionRegistry,
+ parser: ParserInterface): HiveSessionCatalog = throw new UnsupportedOperationException(
+ "to clone HiveSessionCatalog, use the other clone method that also accepts a SparkSession")
+
// For testing only
private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = {
val key = metastoreCatalog.getQualifiedTableName(table)
- sparkSession.sessionState.catalog.tableRelationCache.getIfPresent(key)
+ tableRelationCache.getIfPresent(key)
}
override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = {
@@ -217,3 +250,28 @@ private[sql] class HiveSessionCatalog(
"histogram_numeric"
)
}
+
+private[sql] object HiveSessionCatalog {
+
+ def apply(
+ sparkSession: SparkSession,
+ functionRegistry: FunctionRegistry,
+ conf: SQLConf,
+ hadoopConf: Configuration,
+ parser: ParserInterface): HiveSessionCatalog = {
+ // Catalog for handling data source tables. TODO: This really doesn't belong here since it is
+ // essentially a cache for metastore tables. However, it relies on a lot of session-specific
+ // things so it would be a lot of work to split its functionality between HiveSessionCatalog
+ // and HiveCatalog. We should still do it at some point...
+ val metastoreCatalog = new HiveMetastoreCatalog(sparkSession)
+
+ new HiveSessionCatalog(
+ sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog],
+ sparkSession.sharedState.globalTempViewManager,
+ metastoreCatalog,
+ functionRegistry,
+ conf,
+ hadoopConf,
+ parser)
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index 5a08a6bc66..cb8bcb8591 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -17,89 +17,65 @@
package org.apache.spark.sql.hive
+import org.apache.spark.SparkContext
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.Analyzer
-import org.apache.spark.sql.execution.SparkPlanner
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner, SparkSqlParser}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.hive.client.HiveClient
-import org.apache.spark.sql.internal.SessionState
+import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
+import org.apache.spark.sql.streaming.StreamingQueryManager
/**
* A class that holds all session-specific state in a given [[SparkSession]] backed by Hive.
+ * @param sparkContext The [[SparkContext]].
+ * @param sharedState The shared state.
+ * @param conf SQL-specific key-value configurations.
+ * @param experimentalMethods The experimental methods.
+ * @param functionRegistry Internal catalog for managing functions registered by the user.
+ * @param catalog Internal catalog for managing table and database states that uses Hive client for
+ * interacting with the metastore.
+ * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
+ * @param metadataHive The Hive metadata client.
+ * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations.
+ * @param streamingQueryManager Interface to start and stop
+ * [[org.apache.spark.sql.streaming.StreamingQuery]]s.
+ * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]]
+ * @param plannerCreator Lambda to create a planner that takes into account Hive-specific strategies
*/
-private[hive] class HiveSessionState(sparkSession: SparkSession)
- extends SessionState(sparkSession) {
-
- self =>
-
- /**
- * A Hive client used for interacting with the metastore.
- */
- lazy val metadataHive: HiveClient =
- sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client.newSession()
-
- /**
- * Internal catalog for managing table and database states.
- */
- override lazy val catalog = {
- new HiveSessionCatalog(
- sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog],
- sparkSession.sharedState.globalTempViewManager,
- sparkSession,
- functionResourceLoader,
- functionRegistry,
+private[hive] class HiveSessionState(
+ sparkContext: SparkContext,
+ sharedState: SharedState,
+ conf: SQLConf,
+ experimentalMethods: ExperimentalMethods,
+ functionRegistry: FunctionRegistry,
+ override val catalog: HiveSessionCatalog,
+ sqlParser: ParserInterface,
+ val metadataHive: HiveClient,
+ analyzer: Analyzer,
+ streamingQueryManager: StreamingQueryManager,
+ queryExecutionCreator: LogicalPlan => QueryExecution,
+ val plannerCreator: () => SparkPlanner)
+ extends SessionState(
+ sparkContext,
+ sharedState,
conf,
- newHadoopConf(),
- sqlParser)
- }
-
- /**
- * An analyzer that uses the Hive metastore.
- */
- override lazy val analyzer: Analyzer = {
- new Analyzer(catalog, conf) {
- override val extendedResolutionRules =
- new ResolveHiveSerdeTable(sparkSession) ::
- new FindDataSourceTable(sparkSession) ::
- new ResolveSQLOnFile(sparkSession) :: Nil
-
- override val postHocResolutionRules =
- new DetermineTableStats(sparkSession) ::
- catalog.ParquetConversions ::
- catalog.OrcConversions ::
- PreprocessTableCreation(sparkSession) ::
- PreprocessTableInsertion(conf) ::
- DataSourceAnalysis(conf) ::
- HiveAnalysis :: Nil
-
- override val extendedCheckRules = Seq(PreWriteCheck)
- }
- }
+ experimentalMethods,
+ functionRegistry,
+ catalog,
+ sqlParser,
+ analyzer,
+ streamingQueryManager,
+ queryExecutionCreator) { self =>
/**
* Planner that takes into account Hive-specific strategies.
*/
- override def planner: SparkPlanner = {
- new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies)
- with HiveStrategies {
- override val sparkSession: SparkSession = self.sparkSession
-
- override def strategies: Seq[Strategy] = {
- experimentalMethods.extraStrategies ++ Seq(
- FileSourceStrategy,
- DataSourceStrategy,
- SpecialLimits,
- InMemoryScans,
- HiveTableScans,
- Scripts,
- Aggregation,
- JoinSelection,
- BasicOperators
- )
- }
- }
- }
+ override def planner: SparkPlanner = plannerCreator()
// ------------------------------------------------------
@@ -146,4 +122,149 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC)
}
+ /**
+ * Get an identical copy of the `HiveSessionState`.
+ * This should ideally reuse the `SessionState.clone` but cannot do so.
+ * Doing that will throw an exception when trying to clone the catalog.
+ */
+ override def clone(newSparkSession: SparkSession): HiveSessionState = {
+ val sparkContext = newSparkSession.sparkContext
+ val confCopy = conf.clone()
+ val functionRegistryCopy = functionRegistry.clone()
+ val experimentalMethodsCopy = experimentalMethods.clone()
+ val sqlParser: ParserInterface = new SparkSqlParser(confCopy)
+ val catalogCopy = catalog.newSessionCatalogWith(
+ newSparkSession,
+ confCopy,
+ SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy),
+ functionRegistryCopy,
+ sqlParser)
+ val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan)
+
+ val hiveClient =
+ newSparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
+ .newSession()
+
+ SessionState.mergeSparkConf(confCopy, sparkContext.getConf)
+
+ new HiveSessionState(
+ sparkContext,
+ newSparkSession.sharedState,
+ confCopy,
+ experimentalMethodsCopy,
+ functionRegistryCopy,
+ catalogCopy,
+ sqlParser,
+ hiveClient,
+ HiveSessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy),
+ new StreamingQueryManager(newSparkSession),
+ queryExecutionCreator,
+ HiveSessionState.createPlannerCreator(
+ newSparkSession,
+ confCopy,
+ experimentalMethodsCopy))
+ }
+
+}
+
+private[hive] object HiveSessionState {
+
+ def apply(sparkSession: SparkSession): HiveSessionState = {
+ apply(sparkSession, new SQLConf)
+ }
+
+ def apply(sparkSession: SparkSession, conf: SQLConf): HiveSessionState = {
+ val initHelper = SessionState(sparkSession, conf)
+
+ val sparkContext = sparkSession.sparkContext
+
+ val catalog = HiveSessionCatalog(
+ sparkSession,
+ initHelper.functionRegistry,
+ initHelper.conf,
+ SessionState.newHadoopConf(sparkContext.hadoopConfiguration, initHelper.conf),
+ initHelper.sqlParser)
+
+ val metadataHive: HiveClient =
+ sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
+ .newSession()
+
+ val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, initHelper.conf)
+
+ val plannerCreator = createPlannerCreator(
+ sparkSession,
+ initHelper.conf,
+ initHelper.experimentalMethods)
+
+ val hiveSessionState = new HiveSessionState(
+ sparkContext,
+ sparkSession.sharedState,
+ initHelper.conf,
+ initHelper.experimentalMethods,
+ initHelper.functionRegistry,
+ catalog,
+ initHelper.sqlParser,
+ metadataHive,
+ analyzer,
+ initHelper.streamingQueryManager,
+ initHelper.queryExecutionCreator,
+ plannerCreator)
+ catalog.functionResourceLoader = hiveSessionState.functionResourceLoader
+ hiveSessionState
+ }
+
+ /**
+ * Create a logical query plan `Analyzer` with rules specific to a `HiveSessionState`.
+ */
+ private def createAnalyzer(
+ sparkSession: SparkSession,
+ catalog: HiveSessionCatalog,
+ sqlConf: SQLConf): Analyzer = {
+ new Analyzer(catalog, sqlConf) {
+ override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
+ new ResolveHiveSerdeTable(sparkSession) ::
+ new FindDataSourceTable(sparkSession) ::
+ new ResolveSQLOnFile(sparkSession) :: Nil
+
+ override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
+ new DetermineTableStats(sparkSession) ::
+ catalog.ParquetConversions ::
+ catalog.OrcConversions ::
+ PreprocessTableCreation(sparkSession) ::
+ PreprocessTableInsertion(sqlConf) ::
+ DataSourceAnalysis(sqlConf) ::
+ HiveAnalysis :: Nil
+
+ override val extendedCheckRules = Seq(PreWriteCheck)
+ }
+ }
+
+ private def createPlannerCreator(
+ associatedSparkSession: SparkSession,
+ sqlConf: SQLConf,
+ experimentalMethods: ExperimentalMethods): () => SparkPlanner = {
+ () =>
+ new SparkPlanner(
+ associatedSparkSession.sparkContext,
+ sqlConf,
+ experimentalMethods.extraStrategies)
+ with HiveStrategies {
+
+ override val sparkSession: SparkSession = associatedSparkSession
+
+ override def strategies: Seq[Strategy] = {
+ experimentalMethods.extraStrategies ++ Seq(
+ FileSourceStrategy,
+ DataSourceStrategy,
+ SpecialLimits,
+ InMemoryScans,
+ HiveTableScans,
+ Scripts,
+ Aggregation,
+ JoinSelection,
+ BasicOperators
+ )
+ }
+ }
+ }
}
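
HiveSessionState.apply above avoids duplicating the base wiring: it first builds a plain SessionState as an initialization helper and then constructs the Hive variant from that helper's parts plus the Hive-specific pieces. A toy sketch of that construction pattern with hypothetical classes (not the Spark types):

// Hypothetical illustration of the "init helper" construction pattern.
class BaseState(val conf: Map[String, String], val registry: Set[String])

object BaseState {
  def apply(): BaseState = new BaseState(Map("k" -> "v"), Set("builtin"))
}

class HiveLikeState(conf: Map[String, String], registry: Set[String], val metastoreClient: String)
  extends BaseState(conf, registry)

object HiveLikeState {
  def apply(): HiveLikeState = {
    val initHelper = BaseState() // reuse the base wiring instead of rebuilding it
    new HiveLikeState(initHelper.conf, initHelper.registry, metastoreClient = "hive-client")
  }
}

object InitHelperSketch extends App {
  assert(HiveLikeState().registry == Set("builtin"))
}
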
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index 469c9d84de..6e1f429286 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -278,6 +278,8 @@ private[hive] class HiveClientImpl(
state.getConf.setClassLoader(clientLoader.classLoader)
// Set the thread local metastore client to the client associated with this HiveClientImpl.
Hive.set(client)
+ // Replace conf in the thread local Hive with current conf
+ Hive.get(conf)
// setCurrentSessionState will use the classLoader associated
// with the HiveConf in `state` to override the context class loader of the current
// thread.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index efc2f00984..076c40d459 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -30,16 +30,17 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{SparkSession, SQLContext}
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+import org.apache.spark.sql.{ExperimentalMethods, SparkSession, SQLContext}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner}
import org.apache.spark.sql.execution.command.CacheTableCommand
import org.apache.spark.sql.hive._
-import org.apache.spark.sql.internal.{SharedState, SQLConf}
+import org.apache.spark.sql.hive.client.HiveClient
+import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.util.{ShutdownHookManager, Utils}
// SPARK-3729: Test key required to check for initialization errors with config.
@@ -84,7 +85,7 @@ class TestHiveContext(
new TestHiveContext(sparkSession.newSession())
}
- override def sessionState: TestHiveSessionState = sparkSession.sessionState
+ override def sessionState: HiveSessionState = sparkSession.sessionState
def setCacheTables(c: Boolean): Unit = {
sparkSession.setCacheTables(c)
@@ -144,11 +145,35 @@ private[hive] class TestHiveSparkSession(
existingSharedState.getOrElse(new SharedState(sc))
}
- // TODO: Let's remove TestHiveSessionState. Otherwise, we are not really testing the reflection
- // logic based on the setting of CATALOG_IMPLEMENTATION.
@transient
- override lazy val sessionState: TestHiveSessionState =
- new TestHiveSessionState(self)
+ override lazy val sessionState: HiveSessionState = {
+ val testConf =
+ new SQLConf {
+ clear()
+ override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
+ override def clear(): Unit = {
+ super.clear()
+ TestHiveContext.overrideConfs.foreach { case (k, v) => setConfString(k, v) }
+ }
+ }
+ val queryExecutionCreator = (plan: LogicalPlan) => new TestHiveQueryExecution(this, plan)
+ val initHelper = HiveSessionState(this, testConf)
+ SessionState.mergeSparkConf(testConf, sparkContext.getConf)
+
+ new HiveSessionState(
+ sparkContext,
+ sharedState,
+ testConf,
+ initHelper.experimentalMethods,
+ initHelper.functionRegistry,
+ initHelper.catalog,
+ initHelper.sqlParser,
+ initHelper.metadataHive,
+ initHelper.analyzer,
+ initHelper.streamingQueryManager,
+ queryExecutionCreator,
+ initHelper.plannerCreator)
+ }
override def newSession(): TestHiveSparkSession = {
new TestHiveSparkSession(sc, Some(sharedState), loadTestTables)
@@ -492,26 +517,6 @@ private[hive] class TestHiveQueryExecution(
}
}
-private[hive] class TestHiveSessionState(
- sparkSession: TestHiveSparkSession)
- extends HiveSessionState(sparkSession) { self =>
-
- override lazy val conf: SQLConf = {
- new SQLConf {
- clear()
- override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
- override def clear(): Unit = {
- super.clear()
- TestHiveContext.overrideConfs.foreach { case (k, v) => setConfString(k, v) }
- }
- }
- }
-
- override def executePlan(plan: LogicalPlan): TestHiveQueryExecution = {
- new TestHiveQueryExecution(sparkSession, plan)
- }
-}
-
private[hive] object TestHiveContext {
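
One wrinkle in the `sessionState` override above: `SessionState.mergeSparkConf(testConf, sparkContext.getConf)` folds the Spark configuration into the freshly built test `SQLConf` before it is handed to the new `HiveSessionState`. A rough sketch of what that merge is expected to amount to (an assumption based on how the session state has historically populated its conf, not a copy of the method; it also assumes visibility of `SQLConf`, e.g. from within the `org.apache.spark.sql` package tree):

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.internal.SQLConf

    // Hypothetical stand-in for SessionState.mergeSparkConf: copy every SparkConf entry
    // into the SQLConf so that settings passed via --conf reach the SQL layer.
    def mergeSparkConfSketch(sqlConf: SQLConf, sparkConf: SparkConf): Unit = {
      sparkConf.getAll.foreach { case (k, v) => sqlConf.setConfString(k, v) }
    }
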
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala
new file mode 100644
index 0000000000..3b0f59b159
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry
+import org.apache.spark.sql.catalyst.catalog.CatalogDatabase
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.plans.logical.Range
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.Utils
+
+class HiveSessionCatalogSuite extends TestHiveSingleton {
+
+ test("clone HiveSessionCatalog") {
+ val original = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
+
+ val tempTableName1 = "copytest1"
+ val tempTableName2 = "copytest2"
+ try {
+ val tempTable1 = Range(1, 10, 1, 10)
+ original.createTempView(tempTableName1, tempTable1, overrideIfExists = false)
+
+ // Check that temp views are copied over to the clone.
+ val clone = original.newSessionCatalogWith(
+ spark,
+ new SQLConf,
+ new Configuration(),
+ new SimpleFunctionRegistry,
+ CatalystSqlParser)
+ assert(original ne clone)
+ assert(clone.getTempView(tempTableName1) == Some(tempTable1))
+
+ // Check that the clone and the original are independent.
+ clone.dropTable(TableIdentifier(tempTableName1), ignoreIfNotExists = false, purge = false)
+ assert(original.getTempView(tempTableName1) == Some(tempTable1))
+
+ val tempTable2 = Range(1, 20, 2, 10)
+ original.createTempView(tempTableName2, tempTable2, overrideIfExists = false)
+ assert(clone.getTempView(tempTableName2).isEmpty)
+ } finally {
+ // Drop the created temp views from the global singleton Hive session.
+ original.dropTable(TableIdentifier(tempTableName1), ignoreIfNotExists = true, purge = true)
+ original.dropTable(TableIdentifier(tempTableName2), ignoreIfNotExists = true, purge = true)
+ }
+ }
+
+ test("clone SessionCatalog - current db") {
+ val original = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
+ val originalCurrentDatabase = original.getCurrentDatabase
+ val db1 = "db1"
+ val db2 = "db2"
+ val db3 = "db3"
+ try {
+ original.createDatabase(newDb(db1), ignoreIfExists = true)
+ original.createDatabase(newDb(db2), ignoreIfExists = true)
+ original.createDatabase(newDb(db3), ignoreIfExists = true)
+
+ original.setCurrentDatabase(db1)
+
+ // Create a clone of the original catalog.
+ val clone = original.newSessionCatalogWith(
+ spark,
+ new SQLConf,
+ new Configuration(),
+ new SimpleFunctionRegistry,
+ CatalystSqlParser)
+
+ // Check that the current database is copied over.
+ assert(original ne clone)
+ assert(clone.getCurrentDatabase == db1)
+
+ // Check that the clone and the original are independent.
+ clone.setCurrentDatabase(db2)
+ assert(original.getCurrentDatabase == db1)
+ original.setCurrentDatabase(db3)
+ assert(clone.getCurrentDatabase == db2)
+ } finally {
+ // Drop the created databases from the global singleton Hive session.
+ original.dropDatabase(db1, ignoreIfNotExists = true, cascade = true)
+ original.dropDatabase(db2, ignoreIfNotExists = true, cascade = true)
+ original.dropDatabase(db3, ignoreIfNotExists = true, cascade = true)
+ original.setCurrentDatabase(originalCurrentDatabase)
+ }
+ }
+
+ def newUriForDatabase(): URI = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/"))
+
+ def newDb(name: String): CatalogDatabase = {
+ CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty)
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala
new file mode 100644
index 0000000000..67c77fb62f
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+
+/**
+ * Run all tests from `SessionStateSuite` with a `HiveSessionState`.
+ */
+class HiveSessionStateSuite extends SessionStateSuite
+ with TestHiveSingleton with BeforeAndAfterEach {
+
+ override def beforeAll(): Unit = {
+ // Reuse the singleton session
+ activeSession = spark
+ }
+
+ override def afterAll(): Unit = {
+ // Set activeSession to null to avoid stopping the singleton session
+ activeSession = null
+ super.afterAll()
+ }
+}
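
As a usage note on what the shared `SessionStateSuite` exercises once `activeSession` points at the Hive singleton: sessions derived from it share the `SparkContext` and `SharedState` but carry independent session state, so conf changes made in one session do not leak into another. A minimal illustration using only public API (not part of the patch; `spark` is the `TestHiveSingleton` session, and the example assumes it does not already set the chosen value):

    val original = spark
    val forked = original.newSession()   // shares SparkContext and SharedState, new session state

    forked.conf.set("spark.sql.shuffle.partitions", "7")
    // The forked session's change is not visible in the original session's conf.
    assert(original.conf.get("spark.sql.shuffle.partitions") != "7")
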