From 276c2d51a3bbe2531763a11580adfec7e39fdd58 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Mon, 14 Mar 2016 23:58:57 -0700
Subject: [SPARK-13890][SQL] Remove some internal classes' dependency on SQLContext

## What changes were proposed in this pull request?

In general it is better for internal classes not to depend on the external, user-facing class (in this case SQLContext), to reduce coupling between the public API and the internal implementation. This patch removes the SQLContext dependency from internal classes such as SparkPlanner and SparkOptimizer. (A minimal sketch of the resulting wiring follows the diff below.)

As part of this patch, I also removed the following internal methods from SQLContext:
```
protected[sql] def functionRegistry: FunctionRegistry
protected[sql] def optimizer: Optimizer
protected[sql] def sqlParser: ParserInterface
protected[sql] def planner: SparkPlanner
protected[sql] def continuousQueryManager
protected[sql] def prepareForExecution: RuleExecutor[SparkPlan]
```

## How was this patch tested?

Existing unit/integration tests.

Author: Reynold Xin

Closes #11712 from rxin/sqlContext-planner.
---
 .../org/apache/spark/sql/DataFrameReader.scala      |  3 ++-
 .../org/apache/spark/sql/DataFrameWriter.scala      |  6 ++---
 .../main/scala/org/apache/spark/sql/Dataset.scala   |  6 ++---
 .../org/apache/spark/sql/ExperimentalMethods.scala  |  2 +-
 .../scala/org/apache/spark/sql/SQLContext.scala     | 23 ++++++-------------
 .../spark/sql/execution/QueryExecution.scala        |  6 ++---
 .../spark/sql/execution/SparkOptimizer.scala        | 11 +++++----
 .../apache/spark/sql/execution/SparkPlanner.scala   | 12 ++++++----
 .../spark/sql/execution/SparkStrategies.scala       |  4 ++--
 .../spark/sql/execution/WholeStageCodegen.scala     |  6 ++---
 .../spark/sql/execution/command/commands.scala      |  7 +++---
 .../execution/exchange/EnsureRequirements.scala     | 12 +++++-----
 .../spark/sql/execution/exchange/Exchange.scala     |  6 ++---
 .../org/apache/spark/sql/execution/subquery.scala   |  8 +++----
 .../scala/org/apache/spark/sql/functions.scala      |  2 +-
 .../apache/spark/sql/internal/SessionState.scala    | 16 +++++++------
 .../scala/org/apache/spark/sql/JoinSuite.scala      |  4 ++--
 .../org/apache/spark/sql/SQLContextSuite.scala      |  2 +-
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala  |  3 ++-
 .../apache/spark/sql/execution/PlannerSuite.scala   | 26 +++++++++++-----------
 .../apache/spark/sql/execution/SparkPlanTest.scala  |  2 +-
 .../sql/execution/joins/BroadcastJoinSuite.scala    |  3 ++-
 .../spark/sql/execution/joins/InnerJoinSuite.scala  |  4 ++--
 .../spark/sql/execution/joins/OuterJoinSuite.scala  |  2 +-
 .../spark/sql/execution/joins/SemiJoinSuite.scala   |  2 +-
 25 files changed, 89 insertions(+), 89 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 52b567ea25..76b8d71ac9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -394,7 +394,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
    */
   def table(tableName: String): DataFrame = {
     Dataset.newDataFrame(sqlContext,
-      sqlContext.catalog.lookupRelation(sqlContext.sqlParser.parseTableIdentifier(tableName)))
+      sqlContext.catalog.lookupRelation(
+        sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 3349b8421b..de87f4d7c2 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -242,7 +242,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { options = extraOptions.toMap, partitionColumns = normalizedParCols.getOrElse(Nil)) - df.sqlContext.continuousQueryManager.startQuery( + df.sqlContext.sessionState.continuousQueryManager.startQuery( extraOptions.getOrElse("queryName", StreamExecution.nextName), df, dataSource.createSink()) } @@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - insertInto(df.sqlContext.sqlParser.parseTableIdentifier(tableName)) + insertInto(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { @@ -354,7 +354,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - saveAsTable(df.sqlContext.sqlParser.parseTableIdentifier(tableName)) + saveAsTable(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b5079cf276..ef239a1e2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -818,7 +818,7 @@ class Dataset[T] private[sql]( @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = { select(exprs.map { expr => - Column(sqlContext.sqlParser.parseExpression(expr)) + Column(sqlContext.sessionState.sqlParser.parseExpression(expr)) }: _*) } @@ -919,7 +919,7 @@ class Dataset[T] private[sql]( * @since 1.3.0 */ def filter(conditionExpr: String): Dataset[T] = { - filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr))) + filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr))) } /** @@ -943,7 +943,7 @@ class Dataset[T] private[sql]( * @since 1.5.0 */ def where(conditionExpr: String): Dataset[T] = { - filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr))) + filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index deed45d273..d7cd84fd24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * @since 1.3.0 */ @Experimental -class ExperimentalMethods protected[sql](sqlContext: SQLContext) { +class ExperimentalMethods private[sql]() { /** * Allows extra strategies to be injected into the query planner at runtime. 
Note this API diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 36fe57f78b..0f5d1c8cab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -121,14 +121,7 @@ class SQLContext private[sql]( protected[sql] lazy val sessionState: SessionState = new SessionState(self) protected[sql] def conf: SQLConf = sessionState.conf protected[sql] def catalog: Catalog = sessionState.catalog - protected[sql] def functionRegistry: FunctionRegistry = sessionState.functionRegistry protected[sql] def analyzer: Analyzer = sessionState.analyzer - protected[sql] def optimizer: Optimizer = sessionState.optimizer - protected[sql] def sqlParser: ParserInterface = sessionState.sqlParser - protected[sql] def planner: SparkPlanner = sessionState.planner - protected[sql] def continuousQueryManager = sessionState.continuousQueryManager - protected[sql] def prepareForExecution: RuleExecutor[SparkPlan] = - sessionState.prepareForExecution /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s @@ -197,7 +190,7 @@ class SQLContext private[sql]( */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs - protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser.parsePlan(sql) + protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql) protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql)) @@ -244,7 +237,7 @@ class SQLContext private[sql]( */ @Experimental @transient - val experimental: ExperimentalMethods = new ExperimentalMethods(this) + def experimental: ExperimentalMethods = sessionState.experimentalMethods /** * :: Experimental :: @@ -641,7 +634,7 @@ class SQLContext private[sql]( tableName: String, source: String, options: Map[String, String]): DataFrame = { - val tableIdent = sqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -687,7 +680,7 @@ class SQLContext private[sql]( source: String, schema: StructType, options: Map[String, String]): DataFrame = { - val tableIdent = sqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -706,7 +699,7 @@ class SQLContext private[sql]( * only during the lifetime of this instance of SQLContext. */ private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { - catalog.registerTable(sqlParser.parseTableIdentifier(tableName), df.logicalPlan) + catalog.registerTable(sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan) } /** @@ -800,7 +793,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def table(tableName: String): DataFrame = { - table(sqlParser.parseTableIdentifier(tableName)) + table(sessionState.sqlParser.parseTableIdentifier(tableName)) } private def table(tableIdent: TableIdentifier): DataFrame = { @@ -837,9 +830,7 @@ class SQLContext private[sql]( * * @since 2.0.0 */ - def streams: ContinuousQueryManager = { - continuousQueryManager - } + def streams: ContinuousQueryManager = sessionState.continuousQueryManager /** * Returns the names of tables in the current database as an array. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 9e60c1cd61..5b4254f741 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -45,16 +45,16 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { sqlContext.cacheManager.useCachedData(analyzed) } - lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(withCachedData) + lazy val optimizedPlan: LogicalPlan = sqlContext.sessionState.optimizer.execute(withCachedData) lazy val sparkPlan: SparkPlan = { SQLContext.setActive(sqlContext) - sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next() + sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = sqlContext.prepareForExecution.execute(sparkPlan) + lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index edaf3b36aa..cbde777d98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.optimizer._ +import org.apache.spark.sql.ExperimentalMethods +import org.apache.spark.sql.catalyst.optimizer.Optimizer -class SparkOptimizer(val sqlContext: SQLContext) - extends Optimizer { - override def batches: Seq[Batch] = super.batches :+ Batch( - "User Provided Optimizers", FixedPoint(100), sqlContext.experimental.extraOptimizations: _*) +class SparkOptimizer(experimentalMethods: ExperimentalMethods) extends Optimizer { + override def batches: Seq[Batch] = super.batches :+ Batch( + "User Provided Optimizers", FixedPoint(100), experimentalMethods.extraOptimizations: _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 292d366e72..9da2c74c62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -21,14 +21,18 @@ import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy} +import org.apache.spark.sql.internal.SQLConf -class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { - val sparkContext: SparkContext = sqlContext.sparkContext +class SparkPlanner( + val sparkContext: SparkContext, + val conf: SQLConf, + val experimentalMethods: ExperimentalMethods) + extends SparkStrategies { - def numPartitions: Int = sqlContext.conf.numShufflePartitions + def numPartitions: Int = conf.numShufflePartitions def strategies: Seq[Strategy] = - sqlContext.experimental.extraStrategies ++ ( + 
experimentalMethods.extraStrategies ++ ( FileSourceStrategy :: DataSourceStrategy :: DDLStrategy :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6352c48c76..113cf9ae2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -80,8 +80,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object CanBroadcast { def unapply(plan: LogicalPlan): Option[LogicalPlan] = { - if (sqlContext.conf.autoBroadcastJoinThreshold > 0 && - plan.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) { + if (conf.autoBroadcastJoinThreshold > 0 && + plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) { Some(plan) } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 8fb4705581..81676d3ebb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import org.apache.spark.broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -29,6 +28,7 @@ import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.sql.internal.SQLConf /** * An interface for those physical operators that support codegen. @@ -427,7 +427,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup /** * Find the chained plans that support codegen, collapse them together as WholeStageCodegen. 
*/ -private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] { +case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { private def supportCodegen(e: Expression): Boolean = e match { case e: LeafExpression => true @@ -472,7 +472,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru } def apply(plan: SparkPlan): SparkPlan = { - if (sqlContext.conf.wholeStageEnabled) { + if (conf.wholeStageEnabled) { insertWholeStageCodegen(plan) } else { plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 6e36a15a6d..e711797c1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -358,13 +358,14 @@ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends Ru case Some(p) => try { val regex = java.util.regex.Pattern.compile(p) - sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_)) + sqlContext.sessionState.functionRegistry.listFunction() + .filter(regex.matcher(_).matches()).map(Row(_)) } catch { // probably will failed in the regex that user provided, then returns empty row. case _: Throwable => Seq.empty[Row] } case None => - sqlContext.functionRegistry.listFunction().map(Row(_)) + sqlContext.sessionState.functionRegistry.listFunction().map(Row(_)) } } @@ -395,7 +396,7 @@ case class DescribeFunction( } override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.functionRegistry.lookupFunction(functionName) match { + sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match { case Some(info) => val result = Row(s"Function: ${info.getName}") :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 709a424636..4864db7f2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.execution.exchange -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ +import org.apache.spark.sql.internal.SQLConf /** * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] @@ -30,15 +30,15 @@ import org.apache.spark.sql.execution._ * each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the * input partition ordering requirements are met. 
*/ -private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { - private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions +case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { + private def defaultNumPreShufflePartitions: Int = conf.numShufflePartitions - private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize + private def targetPostShuffleInputSize: Long = conf.targetPostShuffleInputSize - private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled + private def adaptiveExecutionEnabled: Boolean = conf.adaptiveExecutionEnabled private def minNumPostShufflePartitions: Option[Int] = { - val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions + val minNumPostShufflePartitions = conf.minNumPostShufflePartitions if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 12513e9106..9eaadea1b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -22,11 +22,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType /** @@ -64,10 +64,10 @@ case class ReusedExchange(override val output: Seq[Attribute], child: Exchange) * Find out duplicated exchanges in the spark plan, then use the same exchange for all the * references. */ -private[sql] case class ReuseExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { +case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { - if (!sqlContext.conf.exchangeReuseEnabled) { + if (!conf.exchangeReuseEnabled) { return plan } // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index e6d7480b04..0d580703f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.DataType /** @@ -62,12 +62,12 @@ case class ScalarSubquery( /** * Convert the subquery from logical plan into executed plan. 
*/ -case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] { +case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressions { case subquery: expressions.ScalarSubquery => - val sparkPlan = sqlContext.planner.plan(ReturnAnswer(subquery.query)).next() - val executedPlan = sqlContext.prepareForExecution.execute(sparkPlan) + val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next() + val executedPlan = sessionState.prepareForExecution.execute(sparkPlan) ScalarSubquery(executedPlan, subquery.exprId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 326c1e5a7c..dd4aa9e93a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1161,7 +1161,7 @@ object functions { * @group normal_funcs */ def expr(expr: String): Column = { - val parser = SQLContext.getActive().map(_.sqlParser).getOrElse(new CatalystQl()) + val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl()) Column(parser.parseExpression(expr)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 98ada4d58a..e6be0ab3bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.internal -import org.apache.spark.sql.{ContinuousQueryManager, SQLContext, UDFRegistration} +import org.apache.spark.sql.{ContinuousQueryManager, ExperimentalMethods, SQLContext, UDFRegistration} import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog, FunctionRegistry, SimpleCatalog} import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface @@ -40,6 +40,8 @@ private[sql] class SessionState(ctx: SQLContext) { */ lazy val conf = new SQLConf + lazy val experimentalMethods = new ExperimentalMethods + /** * Internal catalog for managing table and database states. */ @@ -73,7 +75,7 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Logical query plan optimizer. */ - lazy val optimizer: Optimizer = new SparkOptimizer(ctx) + lazy val optimizer: Optimizer = new SparkOptimizer(experimentalMethods) /** * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. @@ -83,7 +85,7 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Planner that converts optimized logical plans to physical plans. 
*/ - lazy val planner: SparkPlanner = new SparkPlanner(ctx) + lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) /** * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal @@ -91,10 +93,10 @@ private[sql] class SessionState(ctx: SQLContext) { */ lazy val prepareForExecution = new RuleExecutor[SparkPlan] { override val batches: Seq[Batch] = Seq( - Batch("Subquery", Once, PlanSubqueries(ctx)), - Batch("Add exchange", Once, EnsureRequirements(ctx)), - Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)), - Batch("Reuse duplicated exchanges", Once, ReuseExchange(ctx)) + Batch("Subquery", Once, PlanSubqueries(SessionState.this)), + Batch("Add exchange", Once, EnsureRequirements(conf)), + Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)), + Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf)) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2bd29ef19b..50647c2840 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -37,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = sqlContext.planner.EquiJoinSelection(join) + val planned = sqlContext.sessionState.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -139,7 +139,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = sqlContext.planner.EquiJoinSelection(join) + val planned = sqlContext.sessionState.planner.EquiJoinSelection(join) assert(planned.size === 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index ec19d97d8c..2ad92b52c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -76,6 +76,6 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ test("Catalyst optimization passes are modifiable at runtime") { val sqlContext = SQLContext.getOrCreate(sc) sqlContext.experimental.extraOptimizations = Seq(DummyRule) - assert(sqlContext.optimizer.batches.flatMap(_.rules).contains(DummyRule)) + assert(sqlContext.sessionState.optimizer.batches.flatMap(_.rules).contains(DummyRule)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 98d0008489..836fb1ce85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -54,7 +54,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("show functions") { def getFunctions(pattern: String): Seq[Row] = { val regex = java.util.regex.Pattern.compile(pattern) - sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_)) + sqlContext.sessionState.functionRegistry.listFunction() + .filter(regex.matcher(_).matches()).map(Row(_)) } checkAnswer(sql("SHOW functions"), getFunctions(".*")) Seq("^c.*", ".*e$", "log.*", 
".*date.*").foreach { pattern => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index ab0a7ff628..88fbcda296 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -37,7 +37,7 @@ class PlannerSuite extends SharedSQLContext { setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { - val planner = sqlContext.planner + val planner = sqlContext.sessionState.planner import planner._ val plannedOption = Aggregation(query).headOption val planned = @@ -294,7 +294,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") @@ -314,7 +314,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) } @@ -332,7 +332,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") @@ -352,7 +352,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") @@ -375,7 +375,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(outputOrdering, outputOrdering) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") @@ -391,7 +391,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if 
(outputPlan.collect { case s: Sort => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") @@ -407,7 +407,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: Sort => true }.nonEmpty) { fail(s"No sorts should have been added:\n$outputPlan") @@ -424,7 +424,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA, orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: Sort => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") @@ -443,7 +443,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq.empty)), None) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") @@ -463,7 +463,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq.empty)), None) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") @@ -491,7 +491,7 @@ class PlannerSuite extends SharedSQLContext { shuffle, shuffle) - val outputPlan = ReuseExchange(sqlContext).apply(inputPlan) + val outputPlan = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan) if (outputPlan.collect { case e: ReusedExchange => true }.size != 1) { fail(s"Should re-use the shuffle:\n$outputPlan") } @@ -507,7 +507,7 @@ class PlannerSuite extends SharedSQLContext { ShuffleExchange(finalPartitioning, inputPlan), ShuffleExchange(finalPartitioning, inputPlan)) - val outputPlan2 = ReuseExchange(sqlContext).apply(inputPlan2) + val outputPlan2 = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan2) if (outputPlan2.collect { case e: ReusedExchange => true }.size != 2) { fail(s"Should re-use the two shuffles:\n$outputPlan2") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index aa928cfc80..ed0d3f56e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -233,7 +233,7 @@ object SparkPlanTest { private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. 
- val resolvedPlan = sqlContext.prepareForExecution.execute( + val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index a256ee95a1..6d5b777733 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -63,7 +63,8 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") val df3 = df1.join(broadcast(df2), joinExpression, joinType) - val plan = EnsureRequirements(sqlContext).apply(df3.queryExecution.sparkPlan) + val plan = + EnsureRequirements(sqlContext.sessionState.conf).apply(df3.queryExecution.sparkPlan) assert(plan.collect { case p: T => p }.size === 1) plan.executeCollect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 7eb15249eb..eeb44404e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -98,7 +98,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { boundCondition, leftPlan, rightPlan) - EnsureRequirements(sqlContext).apply(broadcastJoin) + EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin) } def makeSortMergeJoin( @@ -109,7 +109,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { rightPlan: SparkPlan) = { val sortMergeJoin = joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan) - EnsureRequirements(sqlContext).apply(sortMergeJoin) + EnsureRequirements(sqlContext.sessionState.conf).apply(sortMergeJoin) } test(s"$testName using BroadcastHashJoin (build=left)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 0d1c29fe57..4525486430 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -98,7 +98,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( + EnsureRequirements(sqlContext.sessionState.conf).apply( SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index bc341db557..d8c9564f1e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -76,7 +76,7 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(left.sqlContext).apply( + EnsureRequirements(left.sqlContext.sessionState.conf).apply( LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), expectedAnswer.map(Row.fromTuple), sortAnswers = true)
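Below is a minimal Scala sketch, not part of the patch, of how the rewired internals fit together after this change. It uses only constructor signatures that appear in the diff above (new SparkOptimizer(ExperimentalMethods), new SparkPlanner(SparkContext, SQLConf, ExperimentalMethods), and the SQLConf-based physical-plan rules); the object name InternalWiringSketch and its build helper are hypothetical, and the sketch assumes it is compiled inside the org.apache.spark.sql package tree so that the private[sql] constructors are accessible.

```scala
// Hypothetical wiring sketch (not in the patch): mirrors what SessionState does
// after this change, passing each internal component only the state it needs
// instead of the whole SQLContext.
package org.apache.spark.sql.internal

import org.apache.spark.SparkContext
import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.execution.{CollapseCodegenStages, SparkOptimizer, SparkPlanner}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}

object InternalWiringSketch {
  def build(sparkContext: SparkContext): SparkPlanner = {
    val conf = new SQLConf                       // session-local config, no SQLContext needed
    val experimental = new ExperimentalMethods   // was: new ExperimentalMethods(sqlContext)

    // The logical optimizer now only needs the user-provided extra optimizations.
    val optimizer = new SparkOptimizer(experimental)

    // Physical preparation rules now take SQLConf instead of SQLContext.
    val preparations = Seq(
      EnsureRequirements(conf),                  // was: EnsureRequirements(sqlContext)
      CollapseCodegenStages(conf),
      ReuseExchange(conf))

    // The planner depends only on SparkContext, SQLConf and ExperimentalMethods.
    new SparkPlanner(sparkContext, conf, experimental)
  }
}
```

At call sites, the removed SQLContext shortcuts are reached through sessionState instead, e.g. sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName) in the DataFrameReader/DataFrameWriter hunks above.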