path: root/sql/core
author    Herman van Hovell <hvanhovell@databricks.com>  2017-03-28 10:07:24 +0800
committer Wenchen Fan <wenchen@databricks.com>  2017-03-28 10:07:24 +0800
commit    ea361165e1ddce4d8aa0242ae3e878d7b39f1de2 (patch)
tree      f3014ba709d54b48172a399708074480a6ed9661 /sql/core
parent    8a6f33f0483dcee81467e6374a796b5dbd53ea30 (diff)
[SPARK-20100][SQL] Refactor SessionState initialization
## What changes were proposed in this pull request?
The current SessionState initialization code path is quite complex. A part of the creation is done in the SessionState companion objects, a part is done inside the SessionState class itself, and a part is done by passing functions.

This PR refactors this code path and consolidates SessionState initialization into a builder class. The SessionState itself no longer does any initialization and just becomes a placeholder for the various Spark SQL internals. This also lays the groundwork for two future improvements:

1. It provides a starting point for removing the `HiveSessionState`. Removing the `HiveSessionState` would also require us to move resource loading into a separate class, and to (re)move metadata hive.
2. It makes it easier to customize the Spark Session. Currently you would need to create a custom version of the builder; I have added hooks to facilitate this. A future step will be to create a semi-stable API on top of this.

## How was this patch tested?
Existing tests.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #17433 from hvanhovell/SPARK-20100.
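The hooks mentioned above can be seen in context in the new `sessionStateBuilders.scala` file further down in this diff. As a rough sketch of how they are intended to be used (the names `MySessionStateBuilder`, `MyResolutionRule` and `MyStrategy` are hypothetical, and the snippet assumes it is compiled under the `org.apache.spark.sql` package because `SessionState` is `private[sql]`), a custom builder could look like this:

```scala
package org.apache.spark.sql.internal

import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical no-op rule and strategy, present only so the sketch is self-contained.
object MyResolutionRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

object MyStrategy extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
}

class MySessionStateBuilder(
    session: SparkSession,
    parentState: Option[SessionState] = None)
  extends SessionStateBuilder(session, parentState) {

  // Extra analyzer resolution rules, added on top of the built-in ones.
  override protected def customResolutionRules: Seq[Rule[LogicalPlan]] =
    MyResolutionRule :: Nil

  // Extra planning strategies, tried before the regular strategies.
  override protected def customPlanningStrategies: Seq[Strategy] =
    MyStrategy :: Nil

  // Needed so that SessionState.clone() can re-create this builder.
  override protected def newBuilder: NewBuilder = new MySessionStateBuilder(_, _)
}
```

A session would then obtain its state via `new MySessionStateBuilder(sparkSession).build()`, mirroring how `TestSparkSession` wires in `TestSQLSessionStateBuilder` at the end of this diff.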
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala                 |  12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala                   |  11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala |  23
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala                    | 180
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala            | 279
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala                      |  23
6 files changed, 364 insertions(+), 164 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 981728331d..2cdfb7a782 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -30,9 +30,17 @@ class SparkOptimizer(
experimentalMethods: ExperimentalMethods)
extends Optimizer(catalog, conf) {
- override def batches: Seq[Batch] = super.batches :+
+ override def batches: Seq[Batch] = (super.batches :+
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+
Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
- Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++
+ postHocOptimizationBatches :+
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
+
+ /**
+ * Optimization batches that are executed after the regular optimization batches, but before the
+ * batch executing the [[ExperimentalMethods]] optimizer rules. This hook can be used to add
+ * custom optimizer batches to the Spark optimizer.
+ */
+ def postHocOptimizationBatches: Seq[Batch] = Nil
}
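As a hedged illustration of the `postHocOptimizationBatches` hook added above (the names `MyOptimizer` and `MyPostHocRule` are hypothetical, and the sketch is placed under `org.apache.spark.sql.execution` only to keep the internal types accessible), a subclass could contribute a batch that runs after the built-in batches but before the `ExperimentalMethods` batch:

```scala
package org.apache.spark.sql.execution

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

// Hypothetical no-op rule standing in for a real post-hoc optimization.
object MyPostHocRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

class MyOptimizer(
    catalog: SessionCatalog,
    conf: SQLConf,
    experimentalMethods: ExperimentalMethods)
  extends SparkOptimizer(catalog, conf, experimentalMethods) {

  // Injected after the regular optimization batches and before the
  // ExperimentalMethods batch.
  override def postHocOptimizationBatches: Seq[Batch] =
    Batch("My Post Hoc Batch", Once, MyPostHocRule) :: Nil
}
```

Such an optimizer would typically be supplied through the builder's `optimizer` hook introduced in `sessionStateBuilders.scala` further down.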
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 678241656c..6566502bd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -27,13 +27,14 @@ import org.apache.spark.sql.internal.SQLConf
class SparkPlanner(
val sparkContext: SparkContext,
val conf: SQLConf,
- val extraStrategies: Seq[Strategy])
+ val experimentalMethods: ExperimentalMethods)
extends SparkStrategies {
def numPartitions: Int = conf.numShufflePartitions
def strategies: Seq[Strategy] =
- extraStrategies ++ (
+ experimentalMethods.extraStrategies ++
+ extraPlanningStrategies ++ (
FileSourceStrategy ::
DataSourceStrategy ::
SpecialLimits ::
@@ -42,6 +43,12 @@ class SparkPlanner(
InMemoryScans ::
BasicOperators :: Nil)
+ /**
+ * Override to add extra planning strategies to the planner. These strategies are tried after
+ * the strategies defined in [[ExperimentalMethods]], and before the regular strategies.
+ */
+ def extraPlanningStrategies: Seq[Strategy] = Nil
+
override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = {
plan.collect {
case placeholder @ PlanLater(logicalPlan) => placeholder -> logicalPlan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 0f0e4a91f8..622e049630 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.streaming
import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal}
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
@@ -40,20 +40,17 @@ class IncrementalExecution(
offsetSeqMetadata: OffsetSeqMetadata)
extends QueryExecution(sparkSession, logicalPlan) with Logging {
- // TODO: make this always part of planning.
- val streamingExtraStrategies =
- sparkSession.sessionState.planner.StatefulAggregationStrategy +:
- sparkSession.sessionState.planner.FlatMapGroupsWithStateStrategy +:
- sparkSession.sessionState.planner.StreamingRelationStrategy +:
- sparkSession.sessionState.planner.StreamingDeduplicationStrategy +:
- sparkSession.sessionState.experimentalMethods.extraStrategies
-
// Modified planner with stateful operations.
- override def planner: SparkPlanner =
- new SparkPlanner(
+ override val planner: SparkPlanner = new SparkPlanner(
sparkSession.sparkContext,
sparkSession.sessionState.conf,
- streamingExtraStrategies)
+ sparkSession.sessionState.experimentalMethods) {
+ override def extraPlanningStrategies: Seq[Strategy] =
+ StatefulAggregationStrategy ::
+ FlatMapGroupsWithStateStrategy ::
+ StreamingRelationStrategy ::
+ StreamingDeduplicationStrategy :: Nil
+ }
/**
* See [SPARK-18339]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index ce80604bd3..b5b0bb0bfc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -22,22 +22,21 @@ import java.io.File
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.util.ExecutionListenerManager
-
/**
* A class that holds all session-specific state in a given [[SparkSession]].
+ *
* @param sparkContext The [[SparkContext]].
* @param sharedState The shared state.
* @param conf SQL-specific key-value configurations.
@@ -46,9 +45,11 @@ import org.apache.spark.sql.util.ExecutionListenerManager
* @param catalog Internal catalog for managing table and database states.
* @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
* @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations.
- * @param streamingQueryManager Interface to start and stop
- * [[org.apache.spark.sql.streaming.StreamingQuery]]s.
- * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]]
+ * @param optimizer Logical query plan optimizer.
+ * @param planner Planner that converts optimized logical plans to physical plans
+ * @param streamingQueryManager Interface to start and stop streaming queries.
+ * @param createQueryExecution Function used to create QueryExecution objects.
+ * @param createClone Function used to create clones of the session state.
*/
private[sql] class SessionState(
sparkContext: SparkContext,
@@ -59,8 +60,11 @@ private[sql] class SessionState(
val catalog: SessionCatalog,
val sqlParser: ParserInterface,
val analyzer: Analyzer,
+ val optimizer: Optimizer,
+ val planner: SparkPlanner,
val streamingQueryManager: StreamingQueryManager,
- val queryExecutionCreator: LogicalPlan => QueryExecution) {
+ createQueryExecution: LogicalPlan => QueryExecution,
+ createClone: (SparkSession, SessionState) => SessionState) {
def newHadoopConf(): Configuration = SessionState.newHadoopConf(
sparkContext.hadoopConfiguration,
@@ -77,41 +81,12 @@ private[sql] class SessionState(
}
/**
- * A class for loading resources specified by a function.
- */
- val functionResourceLoader: FunctionResourceLoader = {
- new FunctionResourceLoader {
- override def loadResource(resource: FunctionResource): Unit = {
- resource.resourceType match {
- case JarResource => addJar(resource.uri)
- case FileResource => sparkContext.addFile(resource.uri)
- case ArchiveResource =>
- throw new AnalysisException(
- "Archive is not allowed to be loaded. If YARN mode is used, " +
- "please use --archives options while calling spark-submit.")
- }
- }
- }
- }
-
- /**
* Interface exposed to the user for registering user-defined functions.
* Note that the user-defined functions must be deterministic.
*/
val udf: UDFRegistration = new UDFRegistration(functionRegistry)
/**
- * Logical query plan optimizer.
- */
- val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods)
-
- /**
- * Planner that converts optimized logical plans to physical plans.
- */
- def planner: SparkPlanner =
- new SparkPlanner(sparkContext, conf, experimentalMethods.extraStrategies)
-
- /**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
* that listen for execution metrics.
*/
@@ -120,38 +95,13 @@ private[sql] class SessionState(
/**
* Get an identical copy of the `SessionState` and associate it with the given `SparkSession`
*/
- def clone(newSparkSession: SparkSession): SessionState = {
- val sparkContext = newSparkSession.sparkContext
- val confCopy = conf.clone()
- val functionRegistryCopy = functionRegistry.clone()
- val sqlParser: ParserInterface = new SparkSqlParser(confCopy)
- val catalogCopy = catalog.newSessionCatalogWith(
- confCopy,
- SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy),
- functionRegistryCopy,
- sqlParser)
- val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan)
-
- SessionState.mergeSparkConf(confCopy, sparkContext.getConf)
-
- new SessionState(
- sparkContext,
- newSparkSession.sharedState,
- confCopy,
- experimentalMethods.clone(),
- functionRegistryCopy,
- catalogCopy,
- sqlParser,
- SessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy),
- new StreamingQueryManager(newSparkSession),
- queryExecutionCreator)
- }
+ def clone(newSparkSession: SparkSession): SessionState = createClone(newSparkSession, this)
// ------------------------------------------------------
// Helper methods, partially leftover from pre-2.0 days
// ------------------------------------------------------
- def executePlan(plan: LogicalPlan): QueryExecution = queryExecutionCreator(plan)
+ def executePlan(plan: LogicalPlan): QueryExecution = createQueryExecution(plan)
def refreshTable(tableName: String): Unit = {
catalog.refreshTable(sqlParser.parseTableIdentifier(tableName))
@@ -179,53 +129,12 @@ private[sql] class SessionState(
}
}
-
private[sql] object SessionState {
-
- def apply(sparkSession: SparkSession): SessionState = {
- apply(sparkSession, new SQLConf)
- }
-
- def apply(sparkSession: SparkSession, sqlConf: SQLConf): SessionState = {
- val sparkContext = sparkSession.sparkContext
-
- // Automatically extract all entries and put them in our SQLConf
- mergeSparkConf(sqlConf, sparkContext.getConf)
-
- val functionRegistry = FunctionRegistry.builtin.clone()
-
- val sqlParser: ParserInterface = new SparkSqlParser(sqlConf)
-
- val catalog = new SessionCatalog(
- sparkSession.sharedState.externalCatalog,
- sparkSession.sharedState.globalTempViewManager,
- functionRegistry,
- sqlConf,
- newHadoopConf(sparkContext.hadoopConfiguration, sqlConf),
- sqlParser)
-
- val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, sqlConf)
-
- val streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(sparkSession)
-
- val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(sparkSession, plan)
-
- val sessionState = new SessionState(
- sparkContext,
- sparkSession.sharedState,
- sqlConf,
- new ExperimentalMethods,
- functionRegistry,
- catalog,
- sqlParser,
- analyzer,
- streamingQueryManager,
- queryExecutionCreator)
- // functionResourceLoader needs to access SessionState.addJar, so it cannot be created before
- // creating SessionState. Setting `catalog.functionResourceLoader` here is safe since the caller
- // cannot use SessionCatalog before we return SessionState.
- catalog.functionResourceLoader = sessionState.functionResourceLoader
- sessionState
+ /**
+ * Create a new [[SessionState]] for the given session.
+ */
+ def apply(session: SparkSession): SessionState = {
+ new SessionStateBuilder(session).build()
}
def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = {
@@ -233,34 +142,33 @@ private[sql] object SessionState {
sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) }
newHadoopConf
}
+}
- /**
- * Create an logical query plan `Analyzer` with rules specific to a non-Hive `SessionState`.
- */
- private def createAnalyzer(
- sparkSession: SparkSession,
- catalog: SessionCatalog,
- sqlConf: SQLConf): Analyzer = {
- new Analyzer(catalog, sqlConf) {
- override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
- new FindDataSourceTable(sparkSession) ::
- new ResolveSQLOnFile(sparkSession) :: Nil
-
- override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
- PreprocessTableCreation(sparkSession) ::
- PreprocessTableInsertion(sqlConf) ::
- DataSourceAnalysis(sqlConf) :: Nil
-
- override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck)
- }
- }
+/**
+ * Concrete implementation of a [[SessionStateBuilder]].
+ */
+@Experimental
+@InterfaceStability.Unstable
+class SessionStateBuilder(
+ session: SparkSession,
+ parentState: Option[SessionState] = None)
+ extends BaseSessionStateBuilder(session, parentState) {
+ override protected def newBuilder: NewBuilder = new SessionStateBuilder(_, _)
+}
- /**
- * Extract entries from `SparkConf` and put them in the `SQLConf`
- */
- def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = {
- sparkConf.getAll.foreach { case (k, v) =>
- sqlConf.setConfString(k, v)
+/**
+ * Session shared [[FunctionResourceLoader]].
+ */
+@InterfaceStability.Unstable
+class SessionFunctionResourceLoader(session: SparkSession) extends FunctionResourceLoader {
+ override def loadResource(resource: FunctionResource): Unit = {
+ resource.resourceType match {
+ case JarResource => session.sessionState.addJar(resource.uri)
+ case FileResource => session.sparkContext.addFile(resource.uri)
+ case ArchiveResource =>
+ throw new AnalysisException(
+ "Archive is not allowed to be loaded. If YARN mode is used, " +
+ "please use --archives options while calling spark-submit.")
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala
new file mode 100644
index 0000000000..6b5559adb1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.internal
+
+import org.apache.spark.SparkConf
+import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
+import org.apache.spark.sql.catalyst.catalog.SessionCatalog
+import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.streaming.StreamingQueryManager
+
+/**
+ * Builder class that coordinates construction of a new [[SessionState]].
+ *
+ * The builder explicitly defines all components needed by the session state, and creates a session
+ * state when `build` is called. Components should only be initialized once. This is not a problem
+ * for most components as they are only used in the `build` function. However some components
+ * (`conf`, `catalog`, `functionRegistry`, `experimentalMethods` & `sqlParser`) are used as dependencies
+ * for other components and are shared as a result. These components are defined as lazy vals to
+ * make sure the component is created only once.
+ *
+ * A developer can modify the builder by providing custom versions of components, or by using the
+ * hooks provided for the analyzer, optimizer & planner. There are some dependencies between the
+ * components (they are documented per dependency), a developer should respect these when making
+ * modifications in order to prevent initialization problems.
+ *
+ * A parent [[SessionState]] can be used to initialize the new [[SessionState]]. The new session
+ * state will clone the parent session state's `conf`, `functionRegistry`, `experimentalMethods`
+ * and `catalog` fields. Note that the state is cloned when `build` is called, and not before.
+ */
+@Experimental
+@InterfaceStability.Unstable
+abstract class BaseSessionStateBuilder(
+ val session: SparkSession,
+ val parentState: Option[SessionState] = None) {
+ type NewBuilder = (SparkSession, Option[SessionState]) => BaseSessionStateBuilder
+
+ /**
+ * Function that produces a new instance of the SessionStateBuilder. This is used by the
+ * [[SessionState]]'s clone functionality. Make sure to override this when implementing your own
+ * [[SessionStateBuilder]].
+ */
+ protected def newBuilder: NewBuilder
+
+ /**
+ * Extract entries from `SparkConf` and put them in the `SQLConf`
+ */
+ protected def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = {
+ sparkConf.getAll.foreach { case (k, v) =>
+ sqlConf.setConfString(k, v)
+ }
+ }
+
+ /**
+ * SQL-specific key-value configurations.
+ *
+ * These either get cloned from a pre-existing instance or newly created. The conf is always
+ * merged with its [[SparkConf]].
+ */
+ protected lazy val conf: SQLConf = {
+ val conf = parentState.map(_.conf.clone()).getOrElse(new SQLConf)
+ mergeSparkConf(conf, session.sparkContext.conf)
+ conf
+ }
+
+ /**
+ * Internal catalog managing functions registered by the user.
+ *
+ * This either gets cloned from a pre-existing version or cloned from the built-in registry.
+ */
+ protected lazy val functionRegistry: FunctionRegistry = {
+ parentState.map(_.functionRegistry).getOrElse(FunctionRegistry.builtin).clone()
+ }
+
+ /**
+ * Experimental methods that can be used to define custom optimization rules and custom planning
+ * strategies.
+ *
+ * This either gets cloned from a pre-existing version or newly created.
+ */
+ protected lazy val experimentalMethods: ExperimentalMethods = {
+ parentState.map(_.experimentalMethods.clone()).getOrElse(new ExperimentalMethods)
+ }
+
+ /**
+ * Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
+ *
+ * Note: this depends on the `conf` field.
+ */
+ protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf)
+
+ /**
+ * Catalog for managing table and database states. If there is a pre-existing catalog, the state
+ * of that catalog (temp tables & current database) will be copied into the new catalog.
+ *
+ * Note: this depends on the `conf`, `functionRegistry` and `sqlParser` fields.
+ */
+ protected lazy val catalog: SessionCatalog = {
+ val catalog = new SessionCatalog(
+ session.sharedState.externalCatalog,
+ session.sharedState.globalTempViewManager,
+ functionRegistry,
+ conf,
+ SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf),
+ sqlParser,
+ new SessionFunctionResourceLoader(session))
+ parentState.foreach(_.catalog.copyStateTo(catalog))
+ catalog
+ }
+
+ /**
+ * Logical query plan analyzer for resolving unresolved attributes and relations.
+ *
+ * Note: this depends on the `conf` and `catalog` fields.
+ */
+ protected def analyzer: Analyzer = new Analyzer(catalog, conf) {
+ override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
+ new FindDataSourceTable(session) +:
+ new ResolveSQLOnFile(session) +:
+ customResolutionRules
+
+ override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
+ PreprocessTableCreation(session) +:
+ PreprocessTableInsertion(conf) +:
+ DataSourceAnalysis(conf) +:
+ customPostHocResolutionRules
+
+ override val extendedCheckRules: Seq[LogicalPlan => Unit] =
+ PreWriteCheck +:
+ HiveOnlyCheck +:
+ customCheckRules
+ }
+
+ /**
+ * Custom resolution rules to add to the Analyzer. Prefer overriding this instead of creating
+ * your own Analyzer.
+ *
+ * Note that this may NOT depend on the `analyzer` function.
+ */
+ protected def customResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+
+ /**
+ * Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of
+ * creating your own Analyzer.
+ *
+ * Note that this may NOT depend on the `analyzer` function.
+ */
+ protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+
+ /**
+ * Custom check rules to add to the Analyzer. Prefer overriding this instead of creating
+ * your own Analyzer.
+ *
+ * Note that this may NOT depend on the `analyzer` function.
+ */
+ protected def customCheckRules: Seq[LogicalPlan => Unit] = Nil
+
+ /**
+ * Logical query plan optimizer.
+ *
+ * Note: this depends on the `conf`, `catalog` and `experimentalMethods` fields.
+ */
+ protected def optimizer: Optimizer = {
+ new SparkOptimizer(catalog, conf, experimentalMethods) {
+ override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] =
+ super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules
+ }
+ }
+
+ /**
+ * Custom operator optimization rules to add to the Optimizer. Prefer overriding this instead
+ * of creating your own Optimizer.
+ *
+ * Note that this may NOT depend on the `optimizer` function.
+ */
+ protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil
+
+ /**
+ * Planner that converts optimized logical plans to physical plans.
+ *
+ * Note: this depends on the `conf` and `experimentalMethods` fields.
+ */
+ protected def planner: SparkPlanner = {
+ new SparkPlanner(session.sparkContext, conf, experimentalMethods) {
+ override def extraPlanningStrategies: Seq[Strategy] =
+ super.extraPlanningStrategies ++ customPlanningStrategies
+ }
+ }
+
+ /**
+ * Custom strategies to add to the planner. Prefer overriding this instead of creating
+ * your own Planner.
+ *
+ * Note that this may NOT depend on the `planner` function.
+ */
+ protected def customPlanningStrategies: Seq[Strategy] = Nil
+
+ /**
+ * Create a query execution object.
+ */
+ protected def createQueryExecution: LogicalPlan => QueryExecution = { plan =>
+ new QueryExecution(session, plan)
+ }
+
+ /**
+ * Interface to start and stop streaming queries.
+ */
+ protected def streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(session)
+
+ /**
+ * Function used to make clones of the session state.
+ */
+ protected def createClone: (SparkSession, SessionState) => SessionState = {
+ val createBuilder = newBuilder
+ (session, state) => createBuilder(session, Option(state)).build()
+ }
+
+ /**
+ * Build the [[SessionState]].
+ */
+ def build(): SessionState = {
+ new SessionState(
+ session.sparkContext,
+ session.sharedState,
+ conf,
+ experimentalMethods,
+ functionRegistry,
+ catalog,
+ sqlParser,
+ analyzer,
+ optimizer,
+ planner,
+ streamingQueryManager,
+ createQueryExecution,
+ createClone)
+ }
+}
+
+/**
+ * Helper class for using SessionStateBuilders during tests.
+ */
+private[sql] trait WithTestConf { self: BaseSessionStateBuilder =>
+ def overrideConfs: Map[String, String]
+
+ override protected lazy val conf: SQLConf = {
+ val conf = parentState.map(_.conf.clone()).getOrElse {
+ new SQLConf {
+ clear()
+ override def clear(): Unit = {
+ super.clear()
+ // Make sure we start with the default test configs even after clear
+ overrideConfs.foreach { case (key, value) => setConfString(key, value) }
+ }
+ }
+ }
+ mergeSparkConf(conf, session.sparkContext.conf)
+ conf
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 898a2fb4f3..b01977a238 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.test
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.internal.{SessionState, SQLConf}
+import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf, WithTestConf}
/**
* A special [[SparkSession]] prepared for testing.
@@ -35,16 +35,9 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) {
}
@transient
- override lazy val sessionState: SessionState = SessionState(
- this,
- new SQLConf {
- clear()
- override def clear(): Unit = {
- super.clear()
- // Make sure we start with the default test configs even after clear
- TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) }
- }
- })
+ override lazy val sessionState: SessionState = {
+ new TestSQLSessionStateBuilder(this, None).build()
+ }
// Needed for Java tests
def loadTestData(): Unit = {
@@ -67,3 +60,11 @@ private[sql] object TestSQLContext {
// Fewer shuffle partitions to speed up testing.
SQLConf.SHUFFLE_PARTITIONS.key -> "5")
}
+
+private[sql] class TestSQLSessionStateBuilder(
+ session: SparkSession,
+ state: Option[SessionState])
+ extends SessionStateBuilder(session, state) with WithTestConf {
+ override def overrideConfs: Map[String, String] = TestSQLContext.overrideConfs
+ override def newBuilder: NewBuilder = new TestSQLSessionStateBuilder(_, _)
+}