Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 59
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala | 235
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala | 162
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala | 21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala | 18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala | 20
9 files changed, 421 insertions, 111 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
index 1e8ba51e59..bd8dd6ea3f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
@@ -46,4 +46,10 @@ class ExperimentalMethods private[sql]() {
@volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
+ override def clone(): ExperimentalMethods = {
+ val result = new ExperimentalMethods
+ result.extraStrategies = extraStrategies
+ result.extraOptimizations = extraOptimizations
+ result
+ }
}
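
The copy above is shallow: the two @volatile sequences are simply reassigned on a fresh instance, so the clone starts from the same strategy and optimization lists but the two objects can diverge afterwards. A minimal standalone sketch of the same pattern, using a hypothetical Settings class rather than Spark's ExperimentalMethods:

// Hypothetical illustration of the shallow-copy clone pattern used above.
class Settings {
  @volatile var extraStrategies: Seq[String] = Nil
  @volatile var extraOptimizations: Seq[String] = Nil

  override def clone(): Settings = {
    val result = new Settings
    result.extraStrategies = extraStrategies        // copy the current (immutable) references
    result.extraOptimizations = extraOptimizations  // later reassignments stay independent
    result
  }
}

object SettingsDemo extends App {
  val original = new Settings
  original.extraStrategies = Seq("strategyA")
  val copy = original.clone()
  copy.extraStrategies = Seq("strategyB")
  assert(original.extraStrategies == Seq("strategyA")) // the original is unaffected
}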
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index afc1827e7e..49562578b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -21,7 +21,6 @@ import java.io.Closeable
import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConverters._
-import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -43,7 +42,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
-import org.apache.spark.sql.types.{DataType, LongType, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ExecutionListenerManager
import org.apache.spark.util.Utils
@@ -67,15 +66,22 @@ import org.apache.spark.util.Utils
* .config("spark.some.config.option", "some-value")
* .getOrCreate()
* }}}
+ *
+ * @param sparkContext The Spark context associated with this Spark session.
+ * @param existingSharedState If supplied, use the existing shared state
+ * instead of creating a new one.
+ * @param parentSessionState If supplied, inherit all session state (i.e. temporary
+ * views, SQL config, UDFs etc) from parent.
*/
@InterfaceStability.Stable
class SparkSession private(
@transient val sparkContext: SparkContext,
- @transient private val existingSharedState: Option[SharedState])
+ @transient private val existingSharedState: Option[SharedState],
+ @transient private val parentSessionState: Option[SessionState])
extends Serializable with Closeable with Logging { self =>
private[sql] def this(sc: SparkContext) {
- this(sc, None)
+ this(sc, None, None)
}
sparkContext.assertNotStopped()
@@ -108,6 +114,7 @@ class SparkSession private(
/**
* State isolated across sessions, including SQL configurations, temporary tables, registered
* functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]].
+ * If `parentSessionState` is not null, the `SessionState` will be a copy of the parent.
*
* This is internal to Spark and there is no guarantee on interface stability.
*
@@ -116,9 +123,13 @@ class SparkSession private(
@InterfaceStability.Unstable
@transient
lazy val sessionState: SessionState = {
- SparkSession.reflect[SessionState, SparkSession](
- SparkSession.sessionStateClassName(sparkContext.conf),
- self)
+ parentSessionState
+ .map(_.clone(this))
+ .getOrElse {
+ SparkSession.instantiateSessionState(
+ SparkSession.sessionStateClassName(sparkContext.conf),
+ self)
+ }
}
/**
@@ -208,7 +219,25 @@ class SparkSession private(
* @since 2.0.0
*/
def newSession(): SparkSession = {
- new SparkSession(sparkContext, Some(sharedState))
+ new SparkSession(sparkContext, Some(sharedState), parentSessionState = None)
+ }
+
+ /**
+ * Create an identical copy of this `SparkSession`, sharing the underlying `SparkContext`
+ * and shared state. All the state of this session (i.e. SQL configurations, temporary tables,
+ * registered functions) is copied over, and the cloned session is set up with the same shared
+ * state as this session. The cloned session is independent of this session, that is, any
+ * non-global change in either session is not reflected in the other.
+ *
+ * @note Other than the `SparkContext`, all shared state is initialized lazily.
+ * This method will force the initialization of the shared state to ensure that parent
+ * and child sessions are set up with the same shared state. If the underlying catalog
+ * implementation is Hive, this will initialize the metastore, which may take some time.
+ */
+ private[sql] def cloneSession(): SparkSession = {
+ val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState))
+ result.sessionState // force copy of SessionState
+ result
}
@@ -971,16 +1000,18 @@ object SparkSession {
}
/**
- * Helper method to create an instance of [[T]] using a single-arg constructor that
- * accepts an [[Arg]].
+ * Helper method to create an instance of `SessionState` based on `className` from conf.
+ * The result is either `SessionState` or `HiveSessionState`.
*/
- private def reflect[T, Arg <: AnyRef](
+ private def instantiateSessionState(
className: String,
- ctorArg: Arg)(implicit ctorArgTag: ClassTag[Arg]): T = {
+ sparkSession: SparkSession): SessionState = {
+
try {
+ // get `SessionState.apply(SparkSession)`
val clazz = Utils.classForName(className)
- val ctor = clazz.getDeclaredConstructor(ctorArgTag.runtimeClass)
- ctor.newInstance(ctorArg).asInstanceOf[T]
+ val method = clazz.getMethod("apply", sparkSession.getClass)
+ method.invoke(null, sparkSession).asInstanceOf[SessionState]
} catch {
case NonFatal(e) =>
throw new IllegalArgumentException(s"Error while instantiating '$className':", e)
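
For reference, the reworked factory above resolves the session-state class by name and calls the companion object's apply(SparkSession) through the static forwarder that scalac emits on the companion class, which is why invoke is passed null as the receiver. A minimal sketch of that lookup pattern, assuming a hypothetical Greeter class compiled in the default package as a stand-in for SessionState/HiveSessionState:

// Hypothetical stand-in for SessionState: a private constructor plus a companion apply() factory.
class Greeter private (val message: String)

object Greeter {
  def apply(name: String): Greeter = new Greeter(s"hello, $name")
}

object ReflectiveFactoryDemo extends App {
  // scalac emits a static forwarder for Greeter.apply on the companion class, so the
  // getMethod/invoke(null, ...) pattern from instantiateSessionState works unchanged.
  val clazz = Class.forName("Greeter")
  val method = clazz.getMethod("apply", classOf[String])
  val greeter = method.invoke(null, "spark").asInstanceOf[Greeter]
  assert(greeter.message == "hello, spark")
}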
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 4d781b96ab..8b598cc60e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -66,7 +66,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
* Preprocess [[CreateTable]], to do some normalization and checking.
*/
case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[LogicalPlan] {
- private val catalog = sparkSession.sessionState.catalog
+ // catalog is a def and not a val/lazy val as the latter would introduce a circular reference
+ private def catalog = sparkSession.sessionState.catalog
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// When we CREATE TABLE without specifying the table schema, we should fail the query if
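
The def above matters because PreprocessTableCreation is now constructed while SparkSession.sessionState, a lazy val, is still being initialized; an eager val catalog = sparkSession.sessionState.catalog would re-enter that initialization and recurse. A minimal sketch of the cycle, with hypothetical Owner/State/CatalogRule classes:

// Hypothetical illustration of why the rule reads the catalog through a def.
class Owner {
  lazy val state: State = new State(this) // lazy, like SparkSession.sessionState
}

class State(owner: Owner) {
  val catalog: String = "session catalog"
  val rule: CatalogRule = new CatalogRule(owner) // rules are built while State is constructed
}

class CatalogRule(owner: Owner) {
  // A `val catalog = owner.state.catalog` here would re-enter the still-initializing lazy
  // val `state` and recurse; a def defers the lookup until the rule actually runs.
  private def catalog: String = owner.state.catalog
  def apply(plan: String): String = s"$plan resolved against $catalog"
}

object DefVsValDemo extends App {
  val owner = new Owner
  println(owner.state.rule.apply("CREATE TABLE t")) // safe: catalog resolved after init completes
}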
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 94e3fa7dd1..1244f690fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1019,6 +1019,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def clear(): Unit = {
settings.clear()
}
+
+ override def clone(): SQLConf = {
+ val result = new SQLConf
+ getAllConfs.foreach {
+ case(k, v) => if (v ne null) result.setConfString(k, v)
+ }
+ result
+ }
}
/**
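
The new clone() copies every currently-set entry into a fresh SQLConf, so later set calls on either object stay local to it, which is the behaviour exercised by the SQLConfEntrySuite test added further down. A minimal map-backed sketch of the same copy-on-clone idea, using a hypothetical KvConf rather than Spark's class:

import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._

// Hypothetical map-backed config mirroring the clone() added to SQLConf above.
class KvConf {
  private val settings = new ConcurrentHashMap[String, String]()

  def set(key: String, value: String): Unit = settings.put(key, value)
  def get(key: String, default: String): String = Option(settings.get(key)).getOrElse(default)
  def getAll: Map[String, String] = settings.asScala.toMap

  override def clone(): KvConf = {
    val result = new KvConf
    getAll.foreach { case (k, v) => if (v ne null) result.set(k, v) } // same null guard as above
    result
  }
}

object KvConfDemo extends App {
  val original = new KvConf
  original.set("spark.sql.demo", "orig")
  val copy = original.clone()
  copy.set("spark.sql.demo", "copy")
  assert(original.get("spark.sql.demo", "none") == "orig") // independent after cloning
}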
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 6908560511..ce80604bd3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -22,38 +22,49 @@ import java.io.File
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.command.AnalyzeTableCommand
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager}
+import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.util.ExecutionListenerManager
/**
* A class that holds all session-specific state in a given [[SparkSession]].
+ * @param sparkContext The [[SparkContext]].
+ * @param sharedState The shared state.
+ * @param conf SQL-specific key-value configurations.
+ * @param experimentalMethods The experimental methods.
+ * @param functionRegistry Internal catalog for managing functions registered by the user.
+ * @param catalog Internal catalog for managing table and database states.
+ * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
+ * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations.
+ * @param streamingQueryManager Interface to start and stop
+ * [[org.apache.spark.sql.streaming.StreamingQuery]]s.
+ * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]]
*/
-private[sql] class SessionState(sparkSession: SparkSession) {
+private[sql] class SessionState(
+ sparkContext: SparkContext,
+ sharedState: SharedState,
+ val conf: SQLConf,
+ val experimentalMethods: ExperimentalMethods,
+ val functionRegistry: FunctionRegistry,
+ val catalog: SessionCatalog,
+ val sqlParser: ParserInterface,
+ val analyzer: Analyzer,
+ val streamingQueryManager: StreamingQueryManager,
+ val queryExecutionCreator: LogicalPlan => QueryExecution) {
- // Note: These are all lazy vals because they depend on each other (e.g. conf) and we
- // want subclasses to override some of the fields. Otherwise, we would get a lot of NPEs.
-
- /**
- * SQL-specific key-value configurations.
- */
- lazy val conf: SQLConf = new SQLConf
-
- def newHadoopConf(): Configuration = {
- val hadoopConf = new Configuration(sparkSession.sparkContext.hadoopConfiguration)
- conf.getAllConfs.foreach { case (k, v) => if (v ne null) hadoopConf.set(k, v) }
- hadoopConf
- }
+ def newHadoopConf(): Configuration = SessionState.newHadoopConf(
+ sparkContext.hadoopConfiguration,
+ conf)
def newHadoopConfWithOptions(options: Map[String, String]): Configuration = {
val hadoopConf = newHadoopConf()
@@ -65,22 +76,15 @@ private[sql] class SessionState(sparkSession: SparkSession) {
hadoopConf
}
- lazy val experimentalMethods = new ExperimentalMethods
-
- /**
- * Internal catalog for managing functions registered by the user.
- */
- lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()
-
/**
* A class for loading resources specified by a function.
*/
- lazy val functionResourceLoader: FunctionResourceLoader = {
+ val functionResourceLoader: FunctionResourceLoader = {
new FunctionResourceLoader {
override def loadResource(resource: FunctionResource): Unit = {
resource.resourceType match {
case JarResource => addJar(resource.uri)
- case FileResource => sparkSession.sparkContext.addFile(resource.uri)
+ case FileResource => sparkContext.addFile(resource.uri)
case ArchiveResource =>
throw new AnalysisException(
"Archive is not allowed to be loaded. If YARN mode is used, " +
@@ -91,92 +95,77 @@ private[sql] class SessionState(sparkSession: SparkSession) {
}
/**
- * Internal catalog for managing table and database states.
- */
- lazy val catalog = new SessionCatalog(
- sparkSession.sharedState.externalCatalog,
- sparkSession.sharedState.globalTempViewManager,
- functionResourceLoader,
- functionRegistry,
- conf,
- newHadoopConf(),
- sqlParser)
-
- /**
* Interface exposed to the user for registering user-defined functions.
* Note that the user-defined functions must be deterministic.
*/
- lazy val udf: UDFRegistration = new UDFRegistration(functionRegistry)
-
- /**
- * Logical query plan analyzer for resolving unresolved attributes and relations.
- */
- lazy val analyzer: Analyzer = {
- new Analyzer(catalog, conf) {
- override val extendedResolutionRules =
- new FindDataSourceTable(sparkSession) ::
- new ResolveSQLOnFile(sparkSession) :: Nil
-
- override val postHocResolutionRules =
- PreprocessTableCreation(sparkSession) ::
- PreprocessTableInsertion(conf) ::
- DataSourceAnalysis(conf) :: Nil
-
- override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck)
- }
- }
+ val udf: UDFRegistration = new UDFRegistration(functionRegistry)
/**
* Logical query plan optimizer.
*/
- lazy val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods)
-
- /**
- * Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
- */
- lazy val sqlParser: ParserInterface = new SparkSqlParser(conf)
+ val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods)
/**
* Planner that converts optimized logical plans to physical plans.
*/
def planner: SparkPlanner =
- new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies)
+ new SparkPlanner(sparkContext, conf, experimentalMethods.extraStrategies)
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
* that listen for execution metrics.
*/
- lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
+ val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
/**
- * Interface to start and stop [[StreamingQuery]]s.
+ * Get an identical copy of the `SessionState` and associate it with the given `SparkSession`
*/
- lazy val streamingQueryManager: StreamingQueryManager = {
- new StreamingQueryManager(sparkSession)
- }
+ def clone(newSparkSession: SparkSession): SessionState = {
+ val sparkContext = newSparkSession.sparkContext
+ val confCopy = conf.clone()
+ val functionRegistryCopy = functionRegistry.clone()
+ val sqlParser: ParserInterface = new SparkSqlParser(confCopy)
+ val catalogCopy = catalog.newSessionCatalogWith(
+ confCopy,
+ SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy),
+ functionRegistryCopy,
+ sqlParser)
+ val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan)
- private val jarClassLoader: NonClosableMutableURLClassLoader =
- sparkSession.sharedState.jarClassLoader
+ SessionState.mergeSparkConf(confCopy, sparkContext.getConf)
- // Automatically extract all entries and put it in our SQLConf
- // We need to call it after all of vals have been initialized.
- sparkSession.sparkContext.getConf.getAll.foreach { case (k, v) =>
- conf.setConfString(k, v)
+ new SessionState(
+ sparkContext,
+ newSparkSession.sharedState,
+ confCopy,
+ experimentalMethods.clone(),
+ functionRegistryCopy,
+ catalogCopy,
+ sqlParser,
+ SessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy),
+ new StreamingQueryManager(newSparkSession),
+ queryExecutionCreator)
}
// ------------------------------------------------------
// Helper methods, partially leftover from pre-2.0 days
// ------------------------------------------------------
- def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan)
+ def executePlan(plan: LogicalPlan): QueryExecution = queryExecutionCreator(plan)
def refreshTable(tableName: String): Unit = {
catalog.refreshTable(sqlParser.parseTableIdentifier(tableName))
}
+ /**
+ * Add a jar path to [[SparkContext]] and the classloader.
+ *
+ * Note: this method does not seem to access any session state, but the subclass
+ * `HiveSessionState` needs to add the jar to its Hive client for the current session.
+ * Hence, it still needs to live in [[SessionState]].
+ */
def addJar(path: String): Unit = {
- sparkSession.sparkContext.addJar(path)
-
+ sparkContext.addJar(path)
val uri = new Path(path).toUri
val jarURL = if (uri.getScheme == null) {
// `path` is a local file path without a URL scheme
@@ -185,15 +174,93 @@ private[sql] class SessionState(sparkSession: SparkSession) {
// `path` is a URL with a scheme
uri.toURL
}
- jarClassLoader.addURL(jarURL)
- Thread.currentThread().setContextClassLoader(jarClassLoader)
+ sharedState.jarClassLoader.addURL(jarURL)
+ Thread.currentThread().setContextClassLoader(sharedState.jarClassLoader)
+ }
+}
+
+
+private[sql] object SessionState {
+
+ def apply(sparkSession: SparkSession): SessionState = {
+ apply(sparkSession, new SQLConf)
+ }
+
+ def apply(sparkSession: SparkSession, sqlConf: SQLConf): SessionState = {
+ val sparkContext = sparkSession.sparkContext
+
+ // Automatically extract all entries and put them in our SQLConf
+ mergeSparkConf(sqlConf, sparkContext.getConf)
+
+ val functionRegistry = FunctionRegistry.builtin.clone()
+
+ val sqlParser: ParserInterface = new SparkSqlParser(sqlConf)
+
+ val catalog = new SessionCatalog(
+ sparkSession.sharedState.externalCatalog,
+ sparkSession.sharedState.globalTempViewManager,
+ functionRegistry,
+ sqlConf,
+ newHadoopConf(sparkContext.hadoopConfiguration, sqlConf),
+ sqlParser)
+
+ val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, sqlConf)
+
+ val streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(sparkSession)
+
+ val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(sparkSession, plan)
+
+ val sessionState = new SessionState(
+ sparkContext,
+ sparkSession.sharedState,
+ sqlConf,
+ new ExperimentalMethods,
+ functionRegistry,
+ catalog,
+ sqlParser,
+ analyzer,
+ streamingQueryManager,
+ queryExecutionCreator)
+ // functionResourceLoader needs to access SessionState.addJar, so it cannot be created before
+ // creating SessionState. Setting `catalog.functionResourceLoader` here is safe since the caller
+ // cannot use SessionCatalog before we return SessionState.
+ catalog.functionResourceLoader = sessionState.functionResourceLoader
+ sessionState
+ }
+
+ def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = {
+ val newHadoopConf = new Configuration(hadoopConf)
+ sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) }
+ newHadoopConf
+ }
+
+ /**
+ * Create a logical query plan `Analyzer` with rules specific to a non-Hive `SessionState`.
+ */
+ private def createAnalyzer(
+ sparkSession: SparkSession,
+ catalog: SessionCatalog,
+ sqlConf: SQLConf): Analyzer = {
+ new Analyzer(catalog, sqlConf) {
+ override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
+ new FindDataSourceTable(sparkSession) ::
+ new ResolveSQLOnFile(sparkSession) :: Nil
+
+ override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
+ PreprocessTableCreation(sparkSession) ::
+ PreprocessTableInsertion(sqlConf) ::
+ DataSourceAnalysis(sqlConf) :: Nil
+
+ override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck)
+ }
}
/**
- * Analyzes the given table in the current database to generate statistics, which will be
- * used in query optimizations.
+ * Extract entries from `SparkConf` and put them in the `SQLConf`
*/
- def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = {
- AnalyzeTableCommand(tableIdent, noscan).run(sparkSession)
+ def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = {
+ sparkConf.getAll.foreach { case (k, v) =>
+ sqlConf.setConfString(k, v)
+ }
}
}
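
One subtlety in SessionState.apply above: the catalog needs a FunctionResourceLoader, the loader needs SessionState.addJar, and SessionState itself needs the catalog, so the factory builds the catalog without a loader and wires one in afterwards, as the comment notes. A minimal sketch of that construct-then-wire pattern, with hypothetical Catalog/Loader/State classes:

// Hypothetical illustration of breaking the catalog <-> state dependency by late wiring.
trait Loader {
  def loadResource(uri: String): Unit
}

class Catalog {
  // Mutable slot, filled in once the owning state exists (mirrors
  // `catalog.functionResourceLoader = sessionState.functionResourceLoader` above).
  var functionResourceLoader: Loader = _
}

class State(val catalog: Catalog) {
  def addJar(path: String): Unit = println(s"added jar $path")

  val functionResourceLoader: Loader = new Loader {
    override def loadResource(uri: String): Unit = addJar(uri) // needs the constructed State
  }
}

object State {
  def apply(): State = {
    val catalog = new Catalog                                      // built without a loader first
    val state = new State(catalog)                                 // state can now expose addJar
    catalog.functionResourceLoader = state.functionResourceLoader  // safe: catalog not used yet
    state
  }
}

object LateWiringDemo extends App {
  val state = State()
  state.catalog.functionResourceLoader.loadResource("file:///tmp/udfs.jar")
}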
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
new file mode 100644
index 0000000000..2d5e37242a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+class SessionStateSuite extends SparkFunSuite
+ with BeforeAndAfterEach with BeforeAndAfterAll {
+
+ /**
+ * A shared SparkSession for all tests in this suite. Make sure you reset any changes to this
+ * session as this is a singleton HiveSparkSession in HiveSessionStateSuite and it's shared
+ * with all Hive test suites.
+ */
+ protected var activeSession: SparkSession = _
+
+ override def beforeAll(): Unit = {
+ activeSession = SparkSession.builder().master("local").getOrCreate()
+ }
+
+ override def afterAll(): Unit = {
+ if (activeSession != null) {
+ activeSession.stop()
+ activeSession = null
+ }
+ super.afterAll()
+ }
+
+ test("fork new session and inherit RuntimeConfig options") {
+ val key = "spark-config-clone"
+ try {
+ activeSession.conf.set(key, "active")
+
+ // inheritance
+ val forkedSession = activeSession.cloneSession()
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.conf ne activeSession.conf)
+ assert(forkedSession.conf.get(key) == "active")
+
+ // independence
+ forkedSession.conf.set(key, "forked")
+ assert(activeSession.conf.get(key) == "active")
+ activeSession.conf.set(key, "dontcopyme")
+ assert(forkedSession.conf.get(key) == "forked")
+ } finally {
+ activeSession.conf.unset(key)
+ }
+ }
+
+ test("fork new session and inherit function registry and udf") {
+ val testFuncName1 = "strlenScala"
+ val testFuncName2 = "addone"
+ try {
+ activeSession.udf.register(testFuncName1, (_: String).length + (_: Int))
+ val forkedSession = activeSession.cloneSession()
+
+ // inheritance
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.sessionState.functionRegistry ne
+ activeSession.sessionState.functionRegistry)
+ assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty)
+
+ // independence
+ forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1)
+ assert(activeSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty)
+ activeSession.udf.register(testFuncName2, (_: Int) + 1)
+ assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName2).isEmpty)
+ } finally {
+ activeSession.sessionState.functionRegistry.dropFunction(testFuncName1)
+ activeSession.sessionState.functionRegistry.dropFunction(testFuncName2)
+ }
+ }
+
+ test("fork new session and inherit experimental methods") {
+ val originalExtraOptimizations = activeSession.experimental.extraOptimizations
+ val originalExtraStrategies = activeSession.experimental.extraStrategies
+ try {
+ object DummyRule1 extends Rule[LogicalPlan] {
+ def apply(p: LogicalPlan): LogicalPlan = p
+ }
+ object DummyRule2 extends Rule[LogicalPlan] {
+ def apply(p: LogicalPlan): LogicalPlan = p
+ }
+ val optimizations = List(DummyRule1, DummyRule2)
+ activeSession.experimental.extraOptimizations = optimizations
+ val forkedSession = activeSession.cloneSession()
+
+ // inheritance
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.experimental ne activeSession.experimental)
+ assert(forkedSession.experimental.extraOptimizations.toSet ==
+ activeSession.experimental.extraOptimizations.toSet)
+
+ // independence
+ forkedSession.experimental.extraOptimizations = List(DummyRule2)
+ assert(activeSession.experimental.extraOptimizations == optimizations)
+ activeSession.experimental.extraOptimizations = List(DummyRule1)
+ assert(forkedSession.experimental.extraOptimizations == List(DummyRule2))
+ } finally {
+ activeSession.experimental.extraOptimizations = originalExtraOptimizations
+ activeSession.experimental.extraStrategies = originalExtraStrategies
+ }
+ }
+
+ test("fork new sessions and run query on inherited table") {
+ def checkTableExists(sparkSession: SparkSession): Unit = {
+ QueryTest.checkAnswer(sparkSession.sql(
+ """
+ |SELECT x.str, COUNT(*)
+ |FROM df x JOIN df y ON x.str = y.str
+ |GROUP BY x.str
+ """.stripMargin),
+ Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
+ }
+
+ val spark = activeSession
+ // Cannot use `import activeSession.implicits._` due to the compiler limitation.
+ import spark.implicits._
+
+ try {
+ activeSession
+ .createDataset[(Int, String)](Seq(1, 2, 3).map(i => (i, i.toString)))
+ .toDF("int", "str")
+ .createOrReplaceTempView("df")
+ checkTableExists(activeSession)
+
+ val forkedSession = activeSession.cloneSession()
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.sessionState ne activeSession.sessionState)
+ checkTableExists(forkedSession)
+ checkTableExists(activeSession.cloneSession()) // ability to clone multiple times
+ checkTableExists(forkedSession.cloneSession()) // clone of clone
+ } finally {
+ activeSession.sql("drop table df")
+ }
+ }
+
+ test("fork new session and inherit reference to SharedState") {
+ val forkedSession = activeSession.cloneSession()
+ assert(activeSession.sharedState eq forkedSession.sharedState)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 989a7f2698..fcb8ffbc6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -493,6 +493,25 @@ class CatalogSuite
}
}
- // TODO: add tests for the rest of them
+ test("clone Catalog") {
+ // need to test tempTables are cloned
+ assert(spark.catalog.listTables().collect().isEmpty)
+ createTempTable("my_temp_table")
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+
+ // inheritance
+ val forkedSession = spark.cloneSession()
+ assert(spark ne forkedSession)
+ assert(spark.catalog ne forkedSession.catalog)
+ assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+
+ // independence
+ dropTable("my_temp_table") // drop table in original session
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set())
+ assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+ forkedSession.sessionState.catalog
+ .createTempView("fork_table", Range(1, 2, 3, 4), overrideIfExists = true)
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set())
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
index 0e3a5ca9d7..f2456c7704 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
@@ -187,4 +187,22 @@ class SQLConfEntrySuite extends SparkFunSuite {
}
assert(e2.getMessage === "The maximum size of the cache must not be negative")
}
+
+ test("clone SQLConf") {
+ val original = new SQLConf
+ val key = "spark.sql.SQLConfEntrySuite.clone"
+ assert(original.getConfString(key, "noentry") === "noentry")
+
+ // inheritance
+ original.setConfString(key, "orig")
+ val clone = original.clone()
+ assert(original ne clone)
+ assert(clone.getConfString(key, "noentry") === "orig")
+
+ // independence
+ clone.setConfString(key, "clone")
+ assert(original.getConfString(key, "noentry") === "orig")
+ original.setConfString(key, "dontcopyme")
+ assert(clone.getConfString(key, "noentry") === "clone")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 8ab6db175d..898a2fb4f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -35,18 +35,16 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) {
}
@transient
- override lazy val sessionState: SessionState = new SessionState(self) {
- override lazy val conf: SQLConf = {
- new SQLConf {
- clear()
- override def clear(): Unit = {
- super.clear()
- // Make sure we start with the default test configs even after clear
- TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) }
- }
+ override lazy val sessionState: SessionState = SessionState(
+ this,
+ new SQLConf {
+ clear()
+ override def clear(): Unit = {
+ super.clear()
+ // Make sure we start with the default test configs even after clear
+ TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) }
}
- }
- }
+ })
// Needed for Java tests
def loadTestData(): Unit = {