Diffstat (limited to 'sql/core/src/main')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 23
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala | 16
16 files changed, 64 insertions(+), 66 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 52b567ea25..76b8d71ac9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -394,7 +394,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*/
def table(tableName: String): DataFrame = {
Dataset.newDataFrame(sqlContext,
- sqlContext.catalog.lookupRelation(sqlContext.sqlParser.parseTableIdentifier(tableName)))
+ sqlContext.catalog.lookupRelation(
+ sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 3349b8421b..de87f4d7c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -242,7 +242,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
options = extraOptions.toMap,
partitionColumns = normalizedParCols.getOrElse(Nil))
- df.sqlContext.continuousQueryManager.startQuery(
+ df.sqlContext.sessionState.continuousQueryManager.startQuery(
extraOptions.getOrElse("queryName", StreamExecution.nextName), df, dataSource.createSink())
}
@@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def insertInto(tableName: String): Unit = {
- insertInto(df.sqlContext.sqlParser.parseTableIdentifier(tableName))
+ insertInto(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))
}
private def insertInto(tableIdent: TableIdentifier): Unit = {
@@ -354,7 +354,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def saveAsTable(tableName: String): Unit = {
- saveAsTable(df.sqlContext.sqlParser.parseTableIdentifier(tableName))
+ saveAsTable(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))
}
private def saveAsTable(tableIdent: TableIdentifier): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b5079cf276..ef239a1e2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -818,7 +818,7 @@ class Dataset[T] private[sql](
@scala.annotation.varargs
def selectExpr(exprs: String*): DataFrame = {
select(exprs.map { expr =>
- Column(sqlContext.sqlParser.parseExpression(expr))
+ Column(sqlContext.sessionState.sqlParser.parseExpression(expr))
}: _*)
}
@@ -919,7 +919,7 @@ class Dataset[T] private[sql](
* @since 1.3.0
*/
def filter(conditionExpr: String): Dataset[T] = {
- filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
+ filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr)))
}
/**
@@ -943,7 +943,7 @@ class Dataset[T] private[sql](
* @since 1.5.0
*/
def where(conditionExpr: String): Dataset[T] = {
- filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
+ filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr)))
}
/**
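
For context, the string-expression entry points touched above are public API; a minimal usage sketch of what now flows through the session's parser, assuming an existing `sqlContext` and a hypothetical registered "people" table:

// Hypothetical table and columns; only filter/selectExpr come from the code above.
val people = sqlContext.table("people")
// Each string below is parsed by sqlContext.sessionState.sqlParser.parseExpression internally.
val adults  = people.filter("age >= 18")
val renamed = people.selectExpr("name", "age + 1 AS age_next_year")
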
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
index deed45d273..d7cd84fd24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
* @since 1.3.0
*/
@Experimental
-class ExperimentalMethods protected[sql](sqlContext: SQLContext) {
+class ExperimentalMethods private[sql]() {
/**
* Allows extra strategies to be injected into the query planner at runtime. Note this API
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 36fe57f78b..0f5d1c8cab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -121,14 +121,7 @@ class SQLContext private[sql](
protected[sql] lazy val sessionState: SessionState = new SessionState(self)
protected[sql] def conf: SQLConf = sessionState.conf
protected[sql] def catalog: Catalog = sessionState.catalog
- protected[sql] def functionRegistry: FunctionRegistry = sessionState.functionRegistry
protected[sql] def analyzer: Analyzer = sessionState.analyzer
- protected[sql] def optimizer: Optimizer = sessionState.optimizer
- protected[sql] def sqlParser: ParserInterface = sessionState.sqlParser
- protected[sql] def planner: SparkPlanner = sessionState.planner
- protected[sql] def continuousQueryManager = sessionState.continuousQueryManager
- protected[sql] def prepareForExecution: RuleExecutor[SparkPlan] =
- sessionState.prepareForExecution
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@@ -197,7 +190,7 @@ class SQLContext private[sql](
*/
def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
- protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser.parsePlan(sql)
+ protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql)
protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql))
@@ -244,7 +237,7 @@ class SQLContext private[sql](
*/
@Experimental
@transient
- val experimental: ExperimentalMethods = new ExperimentalMethods(this)
+ def experimental: ExperimentalMethods = sessionState.experimentalMethods
/**
* :: Experimental ::
@@ -641,7 +634,7 @@ class SQLContext private[sql](
tableName: String,
source: String,
options: Map[String, String]): DataFrame = {
- val tableIdent = sqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableIdent,
@@ -687,7 +680,7 @@ class SQLContext private[sql](
source: String,
schema: StructType,
options: Map[String, String]): DataFrame = {
- val tableIdent = sqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableIdent,
@@ -706,7 +699,7 @@ class SQLContext private[sql](
* only during the lifetime of this instance of SQLContext.
*/
private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
- catalog.registerTable(sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
+ catalog.registerTable(sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
}
/**
@@ -800,7 +793,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def table(tableName: String): DataFrame = {
- table(sqlParser.parseTableIdentifier(tableName))
+ table(sessionState.sqlParser.parseTableIdentifier(tableName))
}
private def table(tableIdent: TableIdentifier): DataFrame = {
@@ -837,9 +830,7 @@ class SQLContext private[sql](
*
* @since 2.0.0
*/
- def streams: ContinuousQueryManager = {
- continuousQueryManager
- }
+ def streams: ContinuousQueryManager = sessionState.continuousQueryManager
/**
* Returns the names of tables in the current database as an array.
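
With the sqlParser/planner/optimizer shortcuts removed from SQLContext, internal callers spell out the session boundary while the public surface delegates. A minimal sketch, assuming an existing `sqlContext` and package-private access (everything shown except `experimental` and `streams` is protected[sql]):

val parser = sqlContext.sessionState.sqlParser            // previously sqlContext.sqlParser
val ident  = parser.parseTableIdentifier("db.events")     // hypothetical table name
// Public accessors remain, but are now thin delegates into SessionState.
val methods = sqlContext.experimental                     // sessionState.experimentalMethods
val queries = sqlContext.streams                          // sessionState.continuousQueryManager
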
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 9e60c1cd61..5b4254f741 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -45,16 +45,16 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
sqlContext.cacheManager.useCachedData(analyzed)
}
- lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(withCachedData)
+ lazy val optimizedPlan: LogicalPlan = sqlContext.sessionState.optimizer.execute(withCachedData)
lazy val sparkPlan: SparkPlan = {
SQLContext.setActive(sqlContext)
- sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next()
+ sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next()
}
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
- lazy val executedPlan: SparkPlan = sqlContext.prepareForExecution.execute(sparkPlan)
+ lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan)
/** Internal version of the RDD. Avoids copies and has no schema */
lazy val toRdd: RDD[InternalRow] = executedPlan.execute()
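
QueryExecution now reads every planning component off sessionState; a rough sketch of inspecting the resulting stages on an arbitrary DataFrame `df` (the member names below are the ones defined in the hunk above):

val qe = df.queryExecution
qe.analyzed        // resolved logical plan
qe.optimizedPlan   // sessionState.optimizer.execute(withCachedData)
qe.sparkPlan       // first physical plan chosen by sessionState.planner
qe.executedPlan    // sparkPlan after sessionState.prepareForExecution
qe.toRdd           // RDD[InternalRow] produced by the executed plan
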
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index edaf3b36aa..cbde777d98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.optimizer._
+import org.apache.spark.sql.ExperimentalMethods
+import org.apache.spark.sql.catalyst.optimizer.Optimizer
-class SparkOptimizer(val sqlContext: SQLContext)
- extends Optimizer {
- override def batches: Seq[Batch] = super.batches :+ Batch(
- "User Provided Optimizers", FixedPoint(100), sqlContext.experimental.extraOptimizations: _*)
+class SparkOptimizer(experimentalMethods: ExperimentalMethods) extends Optimizer {
+ override def batches: Seq[Batch] = super.batches :+ Batch(
+ "User Provided Optimizers", FixedPoint(100), experimentalMethods.extraOptimizations: _*)
}
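
SparkOptimizer now depends only on ExperimentalMethods rather than a full SQLContext. A minimal sketch of how a user-provided rule still reaches the "User Provided Optimizers" batch, assuming package-private access for the constructors:

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkOptimizer

// A do-nothing rule, purely illustrative.
object NoOpRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

val experimental = new ExperimentalMethods
experimental.extraOptimizations = Seq(NoOpRule)
val optimizer = new SparkOptimizer(experimental)   // the extra batch picks up NoOpRule at FixedPoint(100)
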
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 292d366e72..9da2c74c62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -21,14 +21,18 @@ import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
+import org.apache.spark.sql.internal.SQLConf
-class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies {
- val sparkContext: SparkContext = sqlContext.sparkContext
+class SparkPlanner(
+ val sparkContext: SparkContext,
+ val conf: SQLConf,
+ val experimentalMethods: ExperimentalMethods)
+ extends SparkStrategies {
- def numPartitions: Int = sqlContext.conf.numShufflePartitions
+ def numPartitions: Int = conf.numShufflePartitions
def strategies: Seq[Strategy] =
- sqlContext.experimental.extraStrategies ++ (
+ experimentalMethods.extraStrategies ++ (
FileSourceStrategy ::
DataSourceStrategy ::
DDLStrategy ::
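
The planner is likewise constructed from its actual dependencies. A minimal sketch, assuming an existing SparkContext `sc` and package-private access:

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.execution.SparkPlanner
import org.apache.spark.sql.internal.SQLConf

val conf = new SQLConf                                             // defaults, e.g. spark.sql.shuffle.partitions
val planner = new SparkPlanner(sc, conf, new ExperimentalMethods)
planner.numPartitions                                              // read from the injected conf, not a SQLContext
// planner.strategies == extraStrategies from ExperimentalMethods ++ the built-in strategies
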
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6352c48c76..113cf9ae2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -80,8 +80,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object CanBroadcast {
def unapply(plan: LogicalPlan): Option[LogicalPlan] = {
- if (sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
- plan.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) {
+ if (conf.autoBroadcastJoinThreshold > 0 &&
+ plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
Some(plan)
} else {
None
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 8fb4705581..81676d3ebb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -29,6 +28,7 @@ import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.execution.metric.LongSQLMetricValue
+import org.apache.spark.sql.internal.SQLConf
/**
* An interface for those physical operators that support codegen.
@@ -427,7 +427,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
/**
* Find the chained plans that support codegen, collapse them together as WholeStageCodegen.
*/
-private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {
+case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
private def supportCodegen(e: Expression): Boolean = e match {
case e: LeafExpression => true
@@ -472,7 +472,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
}
def apply(plan: SparkPlan): SparkPlan = {
- if (sqlContext.conf.wholeStageEnabled) {
+ if (conf.wholeStageEnabled) {
insertWholeStageCodegen(plan)
} else {
plan
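
CollapseCodegenStages is now driven purely by SQLConf, so it can be applied without threading a SQLContext through. A sketch, with `physicalPlan` standing in for an already-planned SparkPlan:

import org.apache.spark.sql.execution.{CollapseCodegenStages, SparkPlan}
import org.apache.spark.sql.internal.SQLConf

def collapse(physicalPlan: SparkPlan, conf: SQLConf): SparkPlan =
  CollapseCodegenStages(conf).apply(physicalPlan)   // returns the plan unchanged when wholeStageEnabled is off
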
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 6e36a15a6d..e711797c1b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -358,13 +358,14 @@ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends Ru
case Some(p) =>
try {
val regex = java.util.regex.Pattern.compile(p)
- sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_))
+ sqlContext.sessionState.functionRegistry.listFunction()
+ .filter(regex.matcher(_).matches()).map(Row(_))
} catch {
// probably will failed in the regex that user provided, then returns empty row.
case _: Throwable => Seq.empty[Row]
}
case None =>
- sqlContext.functionRegistry.listFunction().map(Row(_))
+ sqlContext.sessionState.functionRegistry.listFunction().map(Row(_))
}
}
@@ -395,7 +396,7 @@ case class DescribeFunction(
}
override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.functionRegistry.lookupFunction(functionName) match {
+ sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match {
case Some(info) =>
val result =
Row(s"Function: ${info.getName}") ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 709a424636..4864db7f2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.execution.exchange
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.internal.SQLConf
/**
* Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
@@ -30,15 +30,15 @@ import org.apache.spark.sql.execution._
* each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the
* input partition ordering requirements are met.
*/
-private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
- private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions
+case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
+ private def defaultNumPreShufflePartitions: Int = conf.numShufflePartitions
- private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize
+ private def targetPostShuffleInputSize: Long = conf.targetPostShuffleInputSize
- private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled
+ private def adaptiveExecutionEnabled: Boolean = conf.adaptiveExecutionEnabled
private def minNumPostShufflePartitions: Option[Int] = {
- val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions
+ val minNumPostShufflePartitions = conf.minNumPostShufflePartitions
if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
index 12513e9106..9eaadea1b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -22,11 +22,11 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
/**
@@ -64,10 +64,10 @@ case class ReusedExchange(override val output: Seq[Attribute], child: Exchange)
* Find out duplicated exchanges in the spark plan, then use the same exchange for all the
* references.
*/
-private[sql] case class ReuseExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
+case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
- if (!sqlContext.conf.exchangeReuseEnabled) {
+ if (!conf.exchangeReuseEnabled) {
return plan
}
// Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index e6d7480b04..0d580703f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -17,12 +17,12 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.DataType
/**
@@ -62,12 +62,12 @@ case class ScalarSubquery(
/**
* Convert the subquery from logical plan into executed plan.
*/
-case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
+case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressions {
case subquery: expressions.ScalarSubquery =>
- val sparkPlan = sqlContext.planner.plan(ReturnAnswer(subquery.query)).next()
- val executedPlan = sqlContext.prepareForExecution.execute(sparkPlan)
+ val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next()
+ val executedPlan = sessionState.prepareForExecution.execute(sparkPlan)
ScalarSubquery(executedPlan, subquery.exprId)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 326c1e5a7c..dd4aa9e93a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1161,7 +1161,7 @@ object functions {
* @group normal_funcs
*/
def expr(expr: String): Column = {
- val parser = SQLContext.getActive().map(_.sqlParser).getOrElse(new CatalystQl())
+ val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl())
Column(parser.parseExpression(expr))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 98ada4d58a..e6be0ab3bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.internal
-import org.apache.spark.sql.{ContinuousQueryManager, SQLContext, UDFRegistration}
+import org.apache.spark.sql.{ContinuousQueryManager, ExperimentalMethods, SQLContext, UDFRegistration}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog, FunctionRegistry, SimpleCatalog}
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
@@ -40,6 +40,8 @@ private[sql] class SessionState(ctx: SQLContext) {
*/
lazy val conf = new SQLConf
+ lazy val experimentalMethods = new ExperimentalMethods
+
/**
* Internal catalog for managing table and database states.
*/
@@ -73,7 +75,7 @@ private[sql] class SessionState(ctx: SQLContext) {
/**
* Logical query plan optimizer.
*/
- lazy val optimizer: Optimizer = new SparkOptimizer(ctx)
+ lazy val optimizer: Optimizer = new SparkOptimizer(experimentalMethods)
/**
* Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
@@ -83,7 +85,7 @@ private[sql] class SessionState(ctx: SQLContext) {
/**
* Planner that converts optimized logical plans to physical plans.
*/
- lazy val planner: SparkPlanner = new SparkPlanner(ctx)
+ lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods)
/**
* Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
@@ -91,10 +93,10 @@ private[sql] class SessionState(ctx: SQLContext) {
*/
lazy val prepareForExecution = new RuleExecutor[SparkPlan] {
override val batches: Seq[Batch] = Seq(
- Batch("Subquery", Once, PlanSubqueries(ctx)),
- Batch("Add exchange", Once, EnsureRequirements(ctx)),
- Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)),
- Batch("Reuse duplicated exchanges", Once, ReuseExchange(ctx))
+ Batch("Subquery", Once, PlanSubqueries(SessionState.this)),
+ Batch("Add exchange", Once, EnsureRequirements(conf)),
+ Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)),
+ Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf))
)
}
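
Taken together, the preparation pipeline hands each rule only what it needs: PlanSubqueries keeps the SessionState (it must re-enter the planner), while the exchange, codegen, and reuse rules take a plain SQLConf. A sketch of one consequence, under the assumption of package-private access: the conf-only rules can now be composed without any SQLContext or SessionState at all.

import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.{CollapseCodegenStages, SparkPlan}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
import org.apache.spark.sql.internal.SQLConf

// PlanSubqueries is deliberately left out, since it still needs the session's planner.
class ConfOnlyPreparation(conf: SQLConf) extends RuleExecutor[SparkPlan] {
  override val batches: Seq[Batch] = Seq(
    Batch("Add exchange", Once, EnsureRequirements(conf)),
    Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)),
    Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf))
  )
}
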