From ea361165e1ddce4d8aa0242ae3e878d7b39f1de2 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Tue, 28 Mar 2017 10:07:24 +0800
Subject: [SPARK-20100][SQL] Refactor SessionState initialization

## What changes were proposed in this pull request?
The current SessionState initialization code path is quite complex. A part of the creation is done
in the SessionState companion objects, a part of the creation is done inside the SessionState class,
and a part is done by passing functions.

This PR refactors this code path, and consolidates SessionState initialization into a builder class.
This SessionState will not do any initialization and just becomes a placeholder for the various
Spark SQL internals. This also lays the groundwork for two future improvements:

1. This provides us with a start for removing the `HiveSessionState`. Removing the
   `HiveSessionState` would also require us to move resource loading into a separate class, and to
   (re)move metadata hive.
2. This makes it easier to customize the Spark Session. Currently you will need to create a custom
   version of the builder. I have added hooks to facilitate this (an illustrative sketch of such a
   builder is appended after the diff). A future step will be to create a semi-stable API on top of
   this.

## How was this patch tested?
Existing tests.

Author: Herman van Hovell

Closes #17433 from hvanhovell/SPARK-20100.
---
 .../spark/sql/execution/SparkOptimizer.scala       |  12 +-
 .../apache/spark/sql/execution/SparkPlanner.scala  |  11 +-
 .../execution/streaming/IncrementalExecution.scala |  23 +-
 .../apache/spark/sql/internal/SessionState.scala   | 180 ++++---
 .../spark/sql/internal/sessionStateBuilders.scala  | 279 +++++++++++++++++++++
 .../org/apache/spark/sql/test/TestSQLContext.scala |  23 +-
 6 files changed, 364 insertions(+), 164 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala
(limited to 'sql/core')

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 981728331d..2cdfb7a782 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -30,9 +30,17 @@ class SparkOptimizer(
     experimentalMethods: ExperimentalMethods)
   extends Optimizer(catalog, conf) {
 
-  override def batches: Seq[Batch] = super.batches :+
+  override def batches: Seq[Batch] = (super.batches :+
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+
     Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
-    Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
+    Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++
+    postHocOptimizationBatches :+
     Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
+
+  /**
+   * Optimization batches that are executed after the regular optimization batches, but before the
+   * batch executing the [[ExperimentalMethods]] optimizer rules. This hook can be used to add
+   * custom optimizer batches to the Spark optimizer.
+   */
+  def postHocOptimizationBatches: Seq[Batch] = Nil
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 678241656c..6566502bd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -27,13 +27,14 @@ import org.apache.spark.sql.internal.SQLConf
 class SparkPlanner(
     val sparkContext: SparkContext,
     val conf: SQLConf,
-    val extraStrategies: Seq[Strategy])
+    val experimentalMethods: ExperimentalMethods)
   extends SparkStrategies {
 
   def numPartitions: Int = conf.numShufflePartitions
 
   def strategies: Seq[Strategy] =
-      extraStrategies ++ (
+      experimentalMethods.extraStrategies ++
+      extraPlanningStrategies ++ (
       FileSourceStrategy ::
       DataSourceStrategy ::
       SpecialLimits ::
@@ -42,6 +43,12 @@ class SparkPlanner(
       InMemoryScans ::
       BasicOperators :: Nil)
 
+  /**
+   * Override to add extra planning strategies to the planner. These strategies are tried after
+   * the strategies defined in [[ExperimentalMethods]], and before the regular strategies.
+   */
+  def extraPlanningStrategies: Seq[Strategy] = Nil
+
   override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = {
     plan.collect {
       case placeholder @ PlanLater(logicalPlan) => placeholder -> logicalPlan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 0f0e4a91f8..622e049630 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.streaming
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal}
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
@@ -40,20 +40,17 @@ class IncrementalExecution(
     offsetSeqMetadata: OffsetSeqMetadata)
   extends QueryExecution(sparkSession, logicalPlan) with Logging {
 
-  // TODO: make this always part of planning.
-  val streamingExtraStrategies =
-    sparkSession.sessionState.planner.StatefulAggregationStrategy +:
-    sparkSession.sessionState.planner.FlatMapGroupsWithStateStrategy +:
-    sparkSession.sessionState.planner.StreamingRelationStrategy +:
-    sparkSession.sessionState.planner.StreamingDeduplicationStrategy +:
-    sparkSession.sessionState.experimentalMethods.extraStrategies
-
   // Modified planner with stateful operations.
-  override def planner: SparkPlanner =
-    new SparkPlanner(
+  override val planner: SparkPlanner = new SparkPlanner(
       sparkSession.sparkContext,
       sparkSession.sessionState.conf,
-      streamingExtraStrategies)
+      sparkSession.sessionState.experimentalMethods) {
+    override def extraPlanningStrategies: Seq[Strategy] =
+      StatefulAggregationStrategy ::
+      FlatMapGroupsWithStateStrategy ::
+      StreamingRelationStrategy ::
+      StreamingDeduplicationStrategy :: Nil
+  }
 
   /**
    * See [SPARK-18339]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index ce80604bd3..b5b0bb0bfc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -22,22 +22,21 @@ import java.io.File
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.{Experimental, InterfaceStability}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.streaming.StreamingQueryManager
 import org.apache.spark.sql.util.ExecutionListenerManager
 
-
 /**
  * A class that holds all session-specific state in a given [[SparkSession]].
+ *
  * @param sparkContext The [[SparkContext]].
  * @param sharedState The shared state.
  * @param conf SQL-specific key-value configurations.
@@ -46,9 +45,11 @@ import org.apache.spark.sql.util.ExecutionListenerManager
  * @param catalog Internal catalog for managing table and database states.
  * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
  * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations.
- * @param streamingQueryManager Interface to start and stop
- *                              [[org.apache.spark.sql.streaming.StreamingQuery]]s.
- * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]]
+ * @param optimizer Logical query plan optimizer.
+ * @param planner Planner that converts optimized logical plans to physical plans.
+ * @param streamingQueryManager Interface to start and stop streaming queries.
+ * @param createQueryExecution Function used to create QueryExecution objects.
+ * @param createClone Function used to create clones of the session state.
  */
 private[sql] class SessionState(
     sparkContext: SparkContext,
@@ -59,8 +60,11 @@ private[sql] class SessionState(
     val catalog: SessionCatalog,
     val sqlParser: ParserInterface,
     val analyzer: Analyzer,
+    val optimizer: Optimizer,
+    val planner: SparkPlanner,
     val streamingQueryManager: StreamingQueryManager,
-    val queryExecutionCreator: LogicalPlan => QueryExecution) {
+    createQueryExecution: LogicalPlan => QueryExecution,
+    createClone: (SparkSession, SessionState) => SessionState) {
 
   def newHadoopConf(): Configuration = SessionState.newHadoopConf(
     sparkContext.hadoopConfiguration,
@@ -76,41 +80,12 @@ private[sql] class SessionState(
     hadoopConf
   }
 
-  /**
-   * A class for loading resources specified by a function.
-   */
-  val functionResourceLoader: FunctionResourceLoader = {
-    new FunctionResourceLoader {
-      override def loadResource(resource: FunctionResource): Unit = {
-        resource.resourceType match {
-          case JarResource => addJar(resource.uri)
-          case FileResource => sparkContext.addFile(resource.uri)
-          case ArchiveResource =>
-            throw new AnalysisException(
-              "Archive is not allowed to be loaded. If YARN mode is used, " +
-                "please use --archives options while calling spark-submit.")
-        }
-      }
-    }
-  }
-
   /**
    * Interface exposed to the user for registering user-defined functions.
    * Note that the user-defined functions must be deterministic.
    */
   val udf: UDFRegistration = new UDFRegistration(functionRegistry)
 
-  /**
-   * Logical query plan optimizer.
-   */
-  val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods)
-
-  /**
-   * Planner that converts optimized logical plans to physical plans.
-   */
-  def planner: SparkPlanner =
-    new SparkPlanner(sparkContext, conf, experimentalMethods.extraStrategies)
-
   /**
    * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
    * that listen for execution metrics.
@@ -120,38 +95,13 @@ private[sql] class SessionState(
   /**
    * Get an identical copy of the `SessionState` and associate it with the given `SparkSession`
    */
-  def clone(newSparkSession: SparkSession): SessionState = {
-    val sparkContext = newSparkSession.sparkContext
-    val confCopy = conf.clone()
-    val functionRegistryCopy = functionRegistry.clone()
-    val sqlParser: ParserInterface = new SparkSqlParser(confCopy)
-    val catalogCopy = catalog.newSessionCatalogWith(
-      confCopy,
-      SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy),
-      functionRegistryCopy,
-      sqlParser)
-    val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan)
-
-    SessionState.mergeSparkConf(confCopy, sparkContext.getConf)
-
-    new SessionState(
-      sparkContext,
-      newSparkSession.sharedState,
-      confCopy,
-      experimentalMethods.clone(),
-      functionRegistryCopy,
-      catalogCopy,
-      sqlParser,
-      SessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy),
-      new StreamingQueryManager(newSparkSession),
-      queryExecutionCreator)
-  }
+  def clone(newSparkSession: SparkSession): SessionState = createClone(newSparkSession, this)
 
   // ------------------------------------------------------
   // Helper methods, partially leftover from pre-2.0 days
   // ------------------------------------------------------
 
-  def executePlan(plan: LogicalPlan): QueryExecution = queryExecutionCreator(plan)
+  def executePlan(plan: LogicalPlan): QueryExecution = createQueryExecution(plan)
 
   def refreshTable(tableName: String): Unit = {
     catalog.refreshTable(sqlParser.parseTableIdentifier(tableName))
@@ -179,53 +129,12 @@ private[sql] class SessionState(
   }
 }
 
-
 private[sql] object SessionState {
-
-  def apply(sparkSession: SparkSession): SessionState = {
-    apply(sparkSession, new SQLConf)
-  }
-
-  def apply(sparkSession: SparkSession, sqlConf: SQLConf): SessionState = {
-    val sparkContext = sparkSession.sparkContext
-
-    // Automatically extract all entries and put them in our SQLConf
-    mergeSparkConf(sqlConf, sparkContext.getConf)
-
-    val functionRegistry = FunctionRegistry.builtin.clone()
-
-    val sqlParser: ParserInterface = new SparkSqlParser(sqlConf)
-
-    val catalog = new SessionCatalog(
-      sparkSession.sharedState.externalCatalog,
-      sparkSession.sharedState.globalTempViewManager,
-      functionRegistry,
-      sqlConf,
-      newHadoopConf(sparkContext.hadoopConfiguration, sqlConf),
-      sqlParser)
-
-    val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, sqlConf)
-
-    val streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(sparkSession)
-
-    val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(sparkSession, plan)
-
-    val sessionState = new SessionState(
-      sparkContext,
-      sparkSession.sharedState,
-      sqlConf,
-      new ExperimentalMethods,
-      functionRegistry,
-      catalog,
-      sqlParser,
-      analyzer,
-      streamingQueryManager,
-      queryExecutionCreator)
-    // functionResourceLoader needs to access SessionState.addJar, so it cannot be created before
-    // creating SessionState. Setting `catalog.functionResourceLoader` here is safe since the caller
-    // cannot use SessionCatalog before we return SessionState.
-    catalog.functionResourceLoader = sessionState.functionResourceLoader
-    sessionState
+  /**
+   * Create a new [[SessionState]] for the given session.
+   */
+  def apply(session: SparkSession): SessionState = {
+    new SessionStateBuilder(session).build()
   }
 
   def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = {
@@ -233,34 +142,33 @@ private[sql] object SessionState {
     sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) }
     newHadoopConf
   }
+}
 
-  /**
-   * Create an logical query plan `Analyzer` with rules specific to a non-Hive `SessionState`.
-   */
-  private def createAnalyzer(
-      sparkSession: SparkSession,
-      catalog: SessionCatalog,
-      sqlConf: SQLConf): Analyzer = {
-    new Analyzer(catalog, sqlConf) {
-      override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
-        new FindDataSourceTable(sparkSession) ::
-        new ResolveSQLOnFile(sparkSession) :: Nil
-
-      override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
-        PreprocessTableCreation(sparkSession) ::
-        PreprocessTableInsertion(sqlConf) ::
-        DataSourceAnalysis(sqlConf) :: Nil
-
-      override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck)
-    }
-  }
+/**
+ * Concrete implementation of a [[SessionStateBuilder]].
+ */
+@Experimental
+@InterfaceStability.Unstable
+class SessionStateBuilder(
+    session: SparkSession,
+    parentState: Option[SessionState] = None)
+  extends BaseSessionStateBuilder(session, parentState) {
+  override protected def newBuilder: NewBuilder = new SessionStateBuilder(_, _)
+}
 
-  /**
-   * Extract entries from `SparkConf` and put them in the `SQLConf`
-   */
-  def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = {
-    sparkConf.getAll.foreach { case (k, v) =>
-      sqlConf.setConfString(k, v)
+/**
+ * Session shared [[FunctionResourceLoader]].
+ */
+@InterfaceStability.Unstable
+class SessionFunctionResourceLoader(session: SparkSession) extends FunctionResourceLoader {
+  override def loadResource(resource: FunctionResource): Unit = {
+    resource.resourceType match {
+      case JarResource => session.sessionState.addJar(resource.uri)
+      case FileResource => session.sparkContext.addFile(resource.uri)
+      case ArchiveResource =>
+        throw new AnalysisException(
+          "Archive is not allowed to be loaded. If YARN mode is used, " +
+            "please use --archives options while calling spark-submit.")
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala
new file mode 100644
index 0000000000..6b5559adb1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.internal
+
+import org.apache.spark.SparkConf
+import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
+import org.apache.spark.sql.catalyst.catalog.SessionCatalog
+import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.streaming.StreamingQueryManager
+
+/**
+ * Builder class that coordinates construction of a new [[SessionState]].
+ *
+ * The builder explicitly defines all components needed by the session state, and creates a session
+ * state when `build` is called. Components should only be initialized once. This is not a problem
+ * for most components as they are only used in the `build` function. However, some components
+ * (`conf`, `catalog`, `functionRegistry`, `experimentalMethods` & `sqlParser`) are used as
+ * dependencies for other components and are shared as a result. These components are defined as
+ * lazy vals to make sure the component is created only once.
+ *
+ * A developer can modify the builder by providing custom versions of components, or by using the
+ * hooks provided for the analyzer, optimizer & planner. There are some dependencies between the
+ * components (they are documented per dependency); a developer should respect these when making
+ * modifications in order to prevent initialization problems.
+ *
+ * A parent [[SessionState]] can be used to initialize the new [[SessionState]]. The new session
+ * state will clone the parent session state's `conf`, `functionRegistry`, `experimentalMethods`
+ * and `catalog` fields. Note that the state is cloned when `build` is called, and not before.
+ */
+@Experimental
+@InterfaceStability.Unstable
+abstract class BaseSessionStateBuilder(
+    val session: SparkSession,
+    val parentState: Option[SessionState] = None) {
+  type NewBuilder = (SparkSession, Option[SessionState]) => BaseSessionStateBuilder
+
+  /**
+   * Function that produces a new instance of the SessionStateBuilder. This is used by the
+   * [[SessionState]]'s clone functionality. Make sure to override this when implementing your own
+   * [[SessionStateBuilder]].
+   */
+  protected def newBuilder: NewBuilder
+
+  /**
+   * Extract entries from `SparkConf` and put them in the `SQLConf`
+   */
+  protected def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = {
+    sparkConf.getAll.foreach { case (k, v) =>
+      sqlConf.setConfString(k, v)
+    }
+  }
+
+  /**
+   * SQL-specific key-value configurations.
+   *
+   * These either get cloned from a pre-existing instance or newly created. The conf is always
+   * merged with its [[SparkConf]].
+   */
+  protected lazy val conf: SQLConf = {
+    val conf = parentState.map(_.conf.clone()).getOrElse(new SQLConf)
+    mergeSparkConf(conf, session.sparkContext.conf)
+    conf
+  }
+
+  /**
+   * Internal catalog managing functions registered by the user.
+   *
+   * This either gets cloned from a pre-existing version or cloned from the built-in registry.
+   */
+  protected lazy val functionRegistry: FunctionRegistry = {
+    parentState.map(_.functionRegistry).getOrElse(FunctionRegistry.builtin).clone()
+  }
+
+  /**
+   * Experimental methods that can be used to define custom optimization rules and custom planning
+   * strategies.
+   *
+   * This either gets cloned from a pre-existing version or newly created.
+   */
+  protected lazy val experimentalMethods: ExperimentalMethods = {
+    parentState.map(_.experimentalMethods.clone()).getOrElse(new ExperimentalMethods)
+  }
+
+  /**
+   * Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
+   *
+   * Note: this depends on the `conf` field.
+   */
+  protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf)
+
+  /**
+   * Catalog for managing table and database states. If there is a pre-existing catalog, the state
+   * of that catalog (temp tables & current database) will be copied into the new catalog.
+   *
+   * Note: this depends on the `conf`, `functionRegistry` and `sqlParser` fields.
+   */
+  protected lazy val catalog: SessionCatalog = {
+    val catalog = new SessionCatalog(
+      session.sharedState.externalCatalog,
+      session.sharedState.globalTempViewManager,
+      functionRegistry,
+      conf,
+      SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf),
+      sqlParser,
+      new SessionFunctionResourceLoader(session))
+    parentState.foreach(_.catalog.copyStateTo(catalog))
+    catalog
+  }
+
+  /**
+   * Logical query plan analyzer for resolving unresolved attributes and relations.
+   *
+   * Note: this depends on the `conf` and `catalog` fields.
+   */
+  protected def analyzer: Analyzer = new Analyzer(catalog, conf) {
+    override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
+      new FindDataSourceTable(session) +:
+        new ResolveSQLOnFile(session) +:
+        customResolutionRules
+
+    override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
+      PreprocessTableCreation(session) +:
+        PreprocessTableInsertion(conf) +:
+        DataSourceAnalysis(conf) +:
+        customPostHocResolutionRules
+
+    override val extendedCheckRules: Seq[LogicalPlan => Unit] =
+      PreWriteCheck +:
+        HiveOnlyCheck +:
+        customCheckRules
+  }
+
+  /**
+   * Custom resolution rules to add to the Analyzer. Prefer overriding this instead of creating
+   * your own Analyzer.
+   *
+   * Note that this may NOT depend on the `analyzer` function.
+   */
+  protected def customResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+
+  /**
+   * Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of
+   * creating your own Analyzer.
+   *
+   * Note that this may NOT depend on the `analyzer` function.
+   */
+  protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+
+  /**
+   * Custom check rules to add to the Analyzer. Prefer overriding this instead of creating
+   * your own Analyzer.
+   *
+   * Note that this may NOT depend on the `analyzer` function.
+   */
+  protected def customCheckRules: Seq[LogicalPlan => Unit] = Nil
+
+  /**
+   * Logical query plan optimizer.
+   *
+   * Note: this depends on the `conf`, `catalog` and `experimentalMethods` fields.
+   */
+  protected def optimizer: Optimizer = {
+    new SparkOptimizer(catalog, conf, experimentalMethods) {
+      override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] =
+        super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules
+    }
+  }
+
+  /**
+   * Custom operator optimization rules to add to the Optimizer. Prefer overriding this instead
+   * of creating your own Optimizer.
+   *
+   * Note that this may NOT depend on the `optimizer` function.
+   */
+  protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil
+
+  /**
+   * Planner that converts optimized logical plans to physical plans.
+   *
+   * Note: this depends on the `conf` and `experimentalMethods` fields.
+   */
+  protected def planner: SparkPlanner = {
+    new SparkPlanner(session.sparkContext, conf, experimentalMethods) {
+      override def extraPlanningStrategies: Seq[Strategy] =
+        super.extraPlanningStrategies ++ customPlanningStrategies
+    }
+  }
+
+  /**
+   * Custom strategies to add to the planner. Prefer overriding this instead of creating
+   * your own Planner.
+   *
+   * Note that this may NOT depend on the `planner` function.
+   */
+  protected def customPlanningStrategies: Seq[Strategy] = Nil
+
+  /**
+   * Create a query execution object.
+   */
+  protected def createQueryExecution: LogicalPlan => QueryExecution = { plan =>
+    new QueryExecution(session, plan)
+  }
+
+  /**
+   * Interface to start and stop streaming queries.
+   */
+  protected def streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(session)
+
+  /**
+   * Function used to make clones of the session state.
+   */
+  protected def createClone: (SparkSession, SessionState) => SessionState = {
+    val createBuilder = newBuilder
+    (session, state) => createBuilder(session, Option(state)).build()
+  }
+
+  /**
+   * Build the [[SessionState]].
+   */
+  def build(): SessionState = {
+    new SessionState(
+      session.sparkContext,
+      session.sharedState,
+      conf,
+      experimentalMethods,
+      functionRegistry,
+      catalog,
+      sqlParser,
+      analyzer,
+      optimizer,
+      planner,
+      streamingQueryManager,
+      createQueryExecution,
+      createClone)
+  }
+}
+
+/**
+ * Helper trait for using SessionStateBuilders during tests.
+ */
+private[sql] trait WithTestConf { self: BaseSessionStateBuilder =>
+  def overrideConfs: Map[String, String]
+
+  override protected lazy val conf: SQLConf = {
+    val conf = parentState.map(_.conf.clone()).getOrElse {
+      new SQLConf {
+        clear()
+        override def clear(): Unit = {
+          super.clear()
+          // Make sure we start with the default test configs even after clear
+          overrideConfs.foreach { case (key, value) => setConfString(key, value) }
+        }
+      }
+    }
+    mergeSparkConf(conf, session.sparkContext.conf)
+    conf
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 898a2fb4f3..b01977a238 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.test
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.internal.{SessionState, SQLConf}
+import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf, WithTestConf}
 
 /**
  * A special [[SparkSession]] prepared for testing.
@@ -35,16 +35,9 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) {
   }
 
   @transient
-  override lazy val sessionState: SessionState = SessionState(
-    this,
-    new SQLConf {
-      clear()
-      override def clear(): Unit = {
-        super.clear()
-        // Make sure we start with the default test configs even after clear
-        TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) }
-      }
-    })
+  override lazy val sessionState: SessionState = {
+    new TestSQLSessionStateBuilder(this, None).build()
+  }
 
   // Needed for Java tests
   def loadTestData(): Unit = {
@@ -67,3 +60,11 @@ private[sql] object TestSQLContext {
     // Fewer shuffle partitions to speed up testing.
     SQLConf.SHUFFLE_PARTITIONS.key -> "5")
 }
+
+private[sql] class TestSQLSessionStateBuilder(
+    session: SparkSession,
+    state: Option[SessionState])
+  extends SessionStateBuilder(session, state) with WithTestConf {
+  override def overrideConfs: Map[String, String] = TestSQLContext.overrideConfs
+  override def newBuilder: NewBuilder = new TestSQLSessionStateBuilder(_, _)
+}
--
cgit v1.2.3
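
To make the customization path promised in the PR description concrete, the following is a minimal, illustrative sketch of a builder that uses the hooks added by this patch. It is not part of the patch: `MyRewriteRule`, `MySpecialScanStrategy`, and `MySessionStateBuilder` are hypothetical names, and because `SessionState` is still `private[sql]` at this point, such a builder has to live under the `org.apache.spark.sql` package.

```scala
// Hypothetical sketch: none of the My* names below exist in Spark.
package org.apache.spark.sql.internal

import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// A no-op rule standing in for a real analyzer/optimizer rewrite.
object MyRewriteRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

// A strategy that never matches, standing in for a planner extension.
object MySpecialScanStrategy extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
}

class MySessionStateBuilder(
    session: SparkSession,
    parentState: Option[SessionState] = None)
  extends BaseSessionStateBuilder(session, parentState) {

  // Appended to the analyzer's extended resolution rules.
  override protected def customResolutionRules: Seq[Rule[LogicalPlan]] =
    MyRewriteRule :: Nil

  // Appended to the optimizer's operator optimization rules.
  override protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] =
    MyRewriteRule :: Nil

  // Tried after the ExperimentalMethods strategies, before the regular strategies.
  override protected def customPlanningStrategies: Seq[Strategy] =
    MySpecialScanStrategy :: Nil

  // Lets SessionState.clone recreate this builder with a parent state.
  override protected def newBuilder: NewBuilder = new MySessionStateBuilder(_, _)
}
```

A state built this way is obtained with `new MySessionStateBuilder(session).build()`, and `SessionState.clone` routes back through `newBuilder`; how a custom builder gets wired into `SparkSession` is not covered by this patch.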
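
The new `postHocOptimizationBatches` hook on `SparkOptimizer` has no in-tree user in this diff, so here is a hedged sketch of how a subclass might use it. `MyOptimizer` and `MyCleanupRule` are made-up names, and the sketch assumes it is compiled inside Spark's `org.apache.spark.sql.execution` package since these are internal APIs.

```scala
// Hypothetical sketch: MyOptimizer and MyCleanupRule are not part of Spark.
package org.apache.spark.sql.execution

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

// A no-op rule standing in for a real post-hoc rewrite.
object MyCleanupRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

class MyOptimizer(
    catalog: SessionCatalog,
    conf: SQLConf,
    experimentalMethods: ExperimentalMethods)
  extends SparkOptimizer(catalog, conf, experimentalMethods) {

  // Executed after the regular batches and before the "User Provided Optimizers" batch.
  override def postHocOptimizationBatches: Seq[Batch] =
    Batch("Post-Hoc Cleanup", Once, MyCleanupRule) :: Nil
}
```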