diff options
Diffstat (limited to 'sql/core')
9 files changed, 421 insertions, 111 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index 1e8ba51e59..bd8dd6ea3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -46,4 +46,10 @@ class ExperimentalMethods private[sql]() { @volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil + override def clone(): ExperimentalMethods = { + val result = new ExperimentalMethods + result.extraStrategies = extraStrategies + result.extraOptimizations = extraOptimizations + result + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index afc1827e7e..49562578b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,7 +21,6 @@ import java.io.Closeable import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ -import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -43,7 +42,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ -import org.apache.spark.sql.types.{DataType, LongType, StructType} +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.Utils @@ -67,15 +66,22 @@ import org.apache.spark.util.Utils * .config("spark.some.config.option", "some-value") * .getOrCreate() * }}} + * + * @param sparkContext The Spark context associated with this Spark session. + * @param existingSharedState If supplied, use the existing shared state + * instead of creating a new one. + * @param parentSessionState If supplied, inherit all session state (i.e. temporary + * views, SQL config, UDFs etc) from parent. */ @InterfaceStability.Stable class SparkSession private( @transient val sparkContext: SparkContext, - @transient private val existingSharedState: Option[SharedState]) + @transient private val existingSharedState: Option[SharedState], + @transient private val parentSessionState: Option[SessionState]) extends Serializable with Closeable with Logging { self => private[sql] def this(sc: SparkContext) { - this(sc, None) + this(sc, None, None) } sparkContext.assertNotStopped() @@ -108,6 +114,7 @@ class SparkSession private( /** * State isolated across sessions, including SQL configurations, temporary tables, registered * functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]]. + * If `parentSessionState` is not null, the `SessionState` will be a copy of the parent. * * This is internal to Spark and there is no guarantee on interface stability. * @@ -116,9 +123,13 @@ class SparkSession private( @InterfaceStability.Unstable @transient lazy val sessionState: SessionState = { - SparkSession.reflect[SessionState, SparkSession]( - SparkSession.sessionStateClassName(sparkContext.conf), - self) + parentSessionState + .map(_.clone(this)) + .getOrElse { + SparkSession.instantiateSessionState( + SparkSession.sessionStateClassName(sparkContext.conf), + self) + } } /** @@ -208,7 +219,25 @@ class SparkSession private( * @since 2.0.0 */ def newSession(): SparkSession = { - new SparkSession(sparkContext, Some(sharedState)) + new SparkSession(sparkContext, Some(sharedState), parentSessionState = None) + } + + /** + * Create an identical copy of this `SparkSession`, sharing the underlying `SparkContext` + * and shared state. All the state of this session (i.e. SQL configurations, temporary tables, + * registered functions) is copied over, and the cloned session is set up with the same shared + * state as this session. The cloned session is independent of this session, that is, any + * non-global change in either session is not reflected in the other. + * + * @note Other than the `SparkContext`, all shared state is initialized lazily. + * This method will force the initialization of the shared state to ensure that parent + * and child sessions are set up with the same shared state. If the underlying catalog + * implementation is Hive, this will initialize the metastore, which may take some time. + */ + private[sql] def cloneSession(): SparkSession = { + val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState)) + result.sessionState // force copy of SessionState + result } @@ -971,16 +1000,18 @@ object SparkSession { } /** - * Helper method to create an instance of [[T]] using a single-arg constructor that - * accepts an [[Arg]]. + * Helper method to create an instance of `SessionState` based on `className` from conf. + * The result is either `SessionState` or `HiveSessionState`. */ - private def reflect[T, Arg <: AnyRef]( + private def instantiateSessionState( className: String, - ctorArg: Arg)(implicit ctorArgTag: ClassTag[Arg]): T = { + sparkSession: SparkSession): SessionState = { + try { + // get `SessionState.apply(SparkSession)` val clazz = Utils.classForName(className) - val ctor = clazz.getDeclaredConstructor(ctorArgTag.runtimeClass) - ctor.newInstance(ctorArg).asInstanceOf[T] + val method = clazz.getMethod("apply", sparkSession.getClass) + method.invoke(null, sparkSession).asInstanceOf[SessionState] } catch { case NonFatal(e) => throw new IllegalArgumentException(s"Error while instantiating '$className':", e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 4d781b96ab..8b598cc60e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -66,7 +66,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { * Preprocess [[CreateTable]], to do some normalization and checking. */ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[LogicalPlan] { - private val catalog = sparkSession.sessionState.catalog + // catalog is a def and not a val/lazy val as the latter would introduce a circular reference + private def catalog = sparkSession.sessionState.catalog def apply(plan: LogicalPlan): LogicalPlan = plan transform { // When we CREATE TABLE without specifying the table schema, we should fail the query if diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 94e3fa7dd1..1244f690fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1019,6 +1019,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def clear(): Unit = { settings.clear() } + + override def clone(): SQLConf = { + val result = new SQLConf + getAllConfs.foreach { + case(k, v) => if (v ne null) result.setConfString(k, v) + } + result + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 6908560511..ce80604bd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -22,38 +22,49 @@ import java.io.File import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.AnalyzeTableCommand import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager} +import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager /** * A class that holds all session-specific state in a given [[SparkSession]]. + * @param sparkContext The [[SparkContext]]. + * @param sharedState The shared state. + * @param conf SQL-specific key-value configurations. + * @param experimentalMethods The experimental methods. + * @param functionRegistry Internal catalog for managing functions registered by the user. + * @param catalog Internal catalog for managing table and database states. + * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. + * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. + * @param streamingQueryManager Interface to start and stop + * [[org.apache.spark.sql.streaming.StreamingQuery]]s. + * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]] */ -private[sql] class SessionState(sparkSession: SparkSession) { +private[sql] class SessionState( + sparkContext: SparkContext, + sharedState: SharedState, + val conf: SQLConf, + val experimentalMethods: ExperimentalMethods, + val functionRegistry: FunctionRegistry, + val catalog: SessionCatalog, + val sqlParser: ParserInterface, + val analyzer: Analyzer, + val streamingQueryManager: StreamingQueryManager, + val queryExecutionCreator: LogicalPlan => QueryExecution) { - // Note: These are all lazy vals because they depend on each other (e.g. conf) and we - // want subclasses to override some of the fields. Otherwise, we would get a lot of NPEs. - - /** - * SQL-specific key-value configurations. - */ - lazy val conf: SQLConf = new SQLConf - - def newHadoopConf(): Configuration = { - val hadoopConf = new Configuration(sparkSession.sparkContext.hadoopConfiguration) - conf.getAllConfs.foreach { case (k, v) => if (v ne null) hadoopConf.set(k, v) } - hadoopConf - } + def newHadoopConf(): Configuration = SessionState.newHadoopConf( + sparkContext.hadoopConfiguration, + conf) def newHadoopConfWithOptions(options: Map[String, String]): Configuration = { val hadoopConf = newHadoopConf() @@ -65,22 +76,15 @@ private[sql] class SessionState(sparkSession: SparkSession) { hadoopConf } - lazy val experimentalMethods = new ExperimentalMethods - - /** - * Internal catalog for managing functions registered by the user. - */ - lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() - /** * A class for loading resources specified by a function. */ - lazy val functionResourceLoader: FunctionResourceLoader = { + val functionResourceLoader: FunctionResourceLoader = { new FunctionResourceLoader { override def loadResource(resource: FunctionResource): Unit = { resource.resourceType match { case JarResource => addJar(resource.uri) - case FileResource => sparkSession.sparkContext.addFile(resource.uri) + case FileResource => sparkContext.addFile(resource.uri) case ArchiveResource => throw new AnalysisException( "Archive is not allowed to be loaded. If YARN mode is used, " + @@ -91,92 +95,77 @@ private[sql] class SessionState(sparkSession: SparkSession) { } /** - * Internal catalog for managing table and database states. - */ - lazy val catalog = new SessionCatalog( - sparkSession.sharedState.externalCatalog, - sparkSession.sharedState.globalTempViewManager, - functionResourceLoader, - functionRegistry, - conf, - newHadoopConf(), - sqlParser) - - /** * Interface exposed to the user for registering user-defined functions. * Note that the user-defined functions must be deterministic. */ - lazy val udf: UDFRegistration = new UDFRegistration(functionRegistry) - - /** - * Logical query plan analyzer for resolving unresolved attributes and relations. - */ - lazy val analyzer: Analyzer = { - new Analyzer(catalog, conf) { - override val extendedResolutionRules = - new FindDataSourceTable(sparkSession) :: - new ResolveSQLOnFile(sparkSession) :: Nil - - override val postHocResolutionRules = - PreprocessTableCreation(sparkSession) :: - PreprocessTableInsertion(conf) :: - DataSourceAnalysis(conf) :: Nil - - override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck) - } - } + val udf: UDFRegistration = new UDFRegistration(functionRegistry) /** * Logical query plan optimizer. */ - lazy val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods) - - /** - * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. - */ - lazy val sqlParser: ParserInterface = new SparkSqlParser(conf) + val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods) /** * Planner that converts optimized logical plans to physical plans. */ def planner: SparkPlanner = - new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies) + new SparkPlanner(sparkContext, conf, experimentalMethods.extraStrategies) /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s * that listen for execution metrics. */ - lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager + val listenerManager: ExecutionListenerManager = new ExecutionListenerManager /** - * Interface to start and stop [[StreamingQuery]]s. + * Get an identical copy of the `SessionState` and associate it with the given `SparkSession` */ - lazy val streamingQueryManager: StreamingQueryManager = { - new StreamingQueryManager(sparkSession) - } + def clone(newSparkSession: SparkSession): SessionState = { + val sparkContext = newSparkSession.sparkContext + val confCopy = conf.clone() + val functionRegistryCopy = functionRegistry.clone() + val sqlParser: ParserInterface = new SparkSqlParser(confCopy) + val catalogCopy = catalog.newSessionCatalogWith( + confCopy, + SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy), + functionRegistryCopy, + sqlParser) + val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan) - private val jarClassLoader: NonClosableMutableURLClassLoader = - sparkSession.sharedState.jarClassLoader + SessionState.mergeSparkConf(confCopy, sparkContext.getConf) - // Automatically extract all entries and put it in our SQLConf - // We need to call it after all of vals have been initialized. - sparkSession.sparkContext.getConf.getAll.foreach { case (k, v) => - conf.setConfString(k, v) + new SessionState( + sparkContext, + newSparkSession.sharedState, + confCopy, + experimentalMethods.clone(), + functionRegistryCopy, + catalogCopy, + sqlParser, + SessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy), + new StreamingQueryManager(newSparkSession), + queryExecutionCreator) } // ------------------------------------------------------ // Helper methods, partially leftover from pre-2.0 days // ------------------------------------------------------ - def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan) + def executePlan(plan: LogicalPlan): QueryExecution = queryExecutionCreator(plan) def refreshTable(tableName: String): Unit = { catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) } + /** + * Add a jar path to [[SparkContext]] and the classloader. + * + * Note: this method seems not access any session state, but the subclass `HiveSessionState` needs + * to add the jar to its hive client for the current session. Hence, it still needs to be in + * [[SessionState]]. + */ def addJar(path: String): Unit = { - sparkSession.sparkContext.addJar(path) - + sparkContext.addJar(path) val uri = new Path(path).toUri val jarURL = if (uri.getScheme == null) { // `path` is a local file path without a URL scheme @@ -185,15 +174,93 @@ private[sql] class SessionState(sparkSession: SparkSession) { // `path` is a URL with a scheme uri.toURL } - jarClassLoader.addURL(jarURL) - Thread.currentThread().setContextClassLoader(jarClassLoader) + sharedState.jarClassLoader.addURL(jarURL) + Thread.currentThread().setContextClassLoader(sharedState.jarClassLoader) + } +} + + +private[sql] object SessionState { + + def apply(sparkSession: SparkSession): SessionState = { + apply(sparkSession, new SQLConf) + } + + def apply(sparkSession: SparkSession, sqlConf: SQLConf): SessionState = { + val sparkContext = sparkSession.sparkContext + + // Automatically extract all entries and put them in our SQLConf + mergeSparkConf(sqlConf, sparkContext.getConf) + + val functionRegistry = FunctionRegistry.builtin.clone() + + val sqlParser: ParserInterface = new SparkSqlParser(sqlConf) + + val catalog = new SessionCatalog( + sparkSession.sharedState.externalCatalog, + sparkSession.sharedState.globalTempViewManager, + functionRegistry, + sqlConf, + newHadoopConf(sparkContext.hadoopConfiguration, sqlConf), + sqlParser) + + val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, sqlConf) + + val streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(sparkSession) + + val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(sparkSession, plan) + + val sessionState = new SessionState( + sparkContext, + sparkSession.sharedState, + sqlConf, + new ExperimentalMethods, + functionRegistry, + catalog, + sqlParser, + analyzer, + streamingQueryManager, + queryExecutionCreator) + // functionResourceLoader needs to access SessionState.addJar, so it cannot be created before + // creating SessionState. Setting `catalog.functionResourceLoader` here is safe since the caller + // cannot use SessionCatalog before we return SessionState. + catalog.functionResourceLoader = sessionState.functionResourceLoader + sessionState + } + + def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = { + val newHadoopConf = new Configuration(hadoopConf) + sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) } + newHadoopConf + } + + /** + * Create an logical query plan `Analyzer` with rules specific to a non-Hive `SessionState`. + */ + private def createAnalyzer( + sparkSession: SparkSession, + catalog: SessionCatalog, + sqlConf: SQLConf): Analyzer = { + new Analyzer(catalog, sqlConf) { + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = + new FindDataSourceTable(sparkSession) :: + new ResolveSQLOnFile(sparkSession) :: Nil + + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = + PreprocessTableCreation(sparkSession) :: + PreprocessTableInsertion(sqlConf) :: + DataSourceAnalysis(sqlConf) :: Nil + + override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck) + } } /** - * Analyzes the given table in the current database to generate statistics, which will be - * used in query optimizations. + * Extract entries from `SparkConf` and put them in the `SQLConf` */ - def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = { - AnalyzeTableCommand(tableIdent, noscan).run(sparkSession) + def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = { + sparkConf.getAll.foreach { case (k, v) => + sqlConf.setConfString(k, v) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala new file mode 100644 index 0000000000..2d5e37242a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +class SessionStateSuite extends SparkFunSuite + with BeforeAndAfterEach with BeforeAndAfterAll { + + /** + * A shared SparkSession for all tests in this suite. Make sure you reset any changes to this + * session as this is a singleton HiveSparkSession in HiveSessionStateSuite and it's shared + * with all Hive test suites. + */ + protected var activeSession: SparkSession = _ + + override def beforeAll(): Unit = { + activeSession = SparkSession.builder().master("local").getOrCreate() + } + + override def afterAll(): Unit = { + if (activeSession != null) { + activeSession.stop() + activeSession = null + } + super.afterAll() + } + + test("fork new session and inherit RuntimeConfig options") { + val key = "spark-config-clone" + try { + activeSession.conf.set(key, "active") + + // inheritance + val forkedSession = activeSession.cloneSession() + assert(forkedSession ne activeSession) + assert(forkedSession.conf ne activeSession.conf) + assert(forkedSession.conf.get(key) == "active") + + // independence + forkedSession.conf.set(key, "forked") + assert(activeSession.conf.get(key) == "active") + activeSession.conf.set(key, "dontcopyme") + assert(forkedSession.conf.get(key) == "forked") + } finally { + activeSession.conf.unset(key) + } + } + + test("fork new session and inherit function registry and udf") { + val testFuncName1 = "strlenScala" + val testFuncName2 = "addone" + try { + activeSession.udf.register(testFuncName1, (_: String).length + (_: Int)) + val forkedSession = activeSession.cloneSession() + + // inheritance + assert(forkedSession ne activeSession) + assert(forkedSession.sessionState.functionRegistry ne + activeSession.sessionState.functionRegistry) + assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty) + + // independence + forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1) + assert(activeSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty) + activeSession.udf.register(testFuncName2, (_: Int) + 1) + assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName2).isEmpty) + } finally { + activeSession.sessionState.functionRegistry.dropFunction(testFuncName1) + activeSession.sessionState.functionRegistry.dropFunction(testFuncName2) + } + } + + test("fork new session and inherit experimental methods") { + val originalExtraOptimizations = activeSession.experimental.extraOptimizations + val originalExtraStrategies = activeSession.experimental.extraStrategies + try { + object DummyRule1 extends Rule[LogicalPlan] { + def apply(p: LogicalPlan): LogicalPlan = p + } + object DummyRule2 extends Rule[LogicalPlan] { + def apply(p: LogicalPlan): LogicalPlan = p + } + val optimizations = List(DummyRule1, DummyRule2) + activeSession.experimental.extraOptimizations = optimizations + val forkedSession = activeSession.cloneSession() + + // inheritance + assert(forkedSession ne activeSession) + assert(forkedSession.experimental ne activeSession.experimental) + assert(forkedSession.experimental.extraOptimizations.toSet == + activeSession.experimental.extraOptimizations.toSet) + + // independence + forkedSession.experimental.extraOptimizations = List(DummyRule2) + assert(activeSession.experimental.extraOptimizations == optimizations) + activeSession.experimental.extraOptimizations = List(DummyRule1) + assert(forkedSession.experimental.extraOptimizations == List(DummyRule2)) + } finally { + activeSession.experimental.extraOptimizations = originalExtraOptimizations + activeSession.experimental.extraStrategies = originalExtraStrategies + } + } + + test("fork new sessions and run query on inherited table") { + def checkTableExists(sparkSession: SparkSession): Unit = { + QueryTest.checkAnswer(sparkSession.sql( + """ + |SELECT x.str, COUNT(*) + |FROM df x JOIN df y ON x.str = y.str + |GROUP BY x.str + """.stripMargin), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + + val spark = activeSession + // Cannot use `import activeSession.implicits._` due to the compiler limitation. + import spark.implicits._ + + try { + activeSession + .createDataset[(Int, String)](Seq(1, 2, 3).map(i => (i, i.toString))) + .toDF("int", "str") + .createOrReplaceTempView("df") + checkTableExists(activeSession) + + val forkedSession = activeSession.cloneSession() + assert(forkedSession ne activeSession) + assert(forkedSession.sessionState ne activeSession.sessionState) + checkTableExists(forkedSession) + checkTableExists(activeSession.cloneSession()) // ability to clone multiple times + checkTableExists(forkedSession.cloneSession()) // clone of clone + } finally { + activeSession.sql("drop table df") + } + } + + test("fork new session and inherit reference to SharedState") { + val forkedSession = activeSession.cloneSession() + assert(activeSession.sharedState eq forkedSession.sharedState) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 989a7f2698..fcb8ffbc6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -493,6 +493,25 @@ class CatalogSuite } } - // TODO: add tests for the rest of them + test("clone Catalog") { + // need to test tempTables are cloned + assert(spark.catalog.listTables().collect().isEmpty) + createTempTable("my_temp_table") + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table")) + + // inheritance + val forkedSession = spark.cloneSession() + assert(spark ne forkedSession) + assert(spark.catalog ne forkedSession.catalog) + assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table")) + + // independence + dropTable("my_temp_table") // drop table in original session + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set()) + assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table")) + forkedSession.sessionState.catalog + .createTempView("fork_table", Range(1, 2, 3, 4), overrideIfExists = true) + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set()) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala index 0e3a5ca9d7..f2456c7704 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala @@ -187,4 +187,22 @@ class SQLConfEntrySuite extends SparkFunSuite { } assert(e2.getMessage === "The maximum size of the cache must not be negative") } + + test("clone SQLConf") { + val original = new SQLConf + val key = "spark.sql.SQLConfEntrySuite.clone" + assert(original.getConfString(key, "noentry") === "noentry") + + // inheritance + original.setConfString(key, "orig") + val clone = original.clone() + assert(original ne clone) + assert(clone.getConfString(key, "noentry") === "orig") + + // independence + clone.setConfString(key, "clone") + assert(original.getConfString(key, "noentry") === "orig") + original.setConfString(key, "dontcopyme") + assert(clone.getConfString(key, "noentry") === "clone") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 8ab6db175d..898a2fb4f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -35,18 +35,16 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { } @transient - override lazy val sessionState: SessionState = new SessionState(self) { - override lazy val conf: SQLConf = { - new SQLConf { - clear() - override def clear(): Unit = { - super.clear() - // Make sure we start with the default test configs even after clear - TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) } - } + override lazy val sessionState: SessionState = SessionState( + this, + new SQLConf { + clear() + override def clear(): Unit = { + super.clear() + // Make sure we start with the default test configs even after clear + TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) } } - } - } + }) // Needed for Java tests def loadTestData(): Unit = { |