path: root/sql/core/src
author    Reynold Xin <rxin@databricks.com>    2016-03-14 23:58:57 -0700
committer Reynold Xin <rxin@databricks.com>    2016-03-14 23:58:57 -0700
commit    276c2d51a3bbe2531763a11580adfec7e39fdd58 (patch)
tree      e2e6d063986795847167109e6d58464d1b376a39 /sql/core/src
parent    a51f877b5dc56b7bb9ef95044a50024c6b64718e (diff)
[SPARK-13890][SQL] Remove some internal classes' dependency on SQLContext
## What changes were proposed in this pull request?

In general it is better for internal classes not to depend on the external class (in this case SQLContext), to reduce coupling between user-facing APIs and the internal implementations. This patch removes the SQLContext dependency from some internal classes such as SparkPlanner and SparkOptimizer. As part of this patch, I also removed the following internal methods from SQLContext:

```
protected[sql] def functionRegistry: FunctionRegistry
protected[sql] def optimizer: Optimizer
protected[sql] def sqlParser: ParserInterface
protected[sql] def planner: SparkPlanner
protected[sql] def continuousQueryManager
protected[sql] def prepareForExecution: RuleExecutor[SparkPlan]
```

## How was this patch tested?

Existing unit/integration tests.

Author: Reynold Xin <rxin@databricks.com>

Closes #11712 from rxin/sqlContext-planner.
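For readers skimming the diff below, the net effect on callers of these internals looks roughly like the following sketch. It mirrors the updated QueryExecution further down and is an illustration only, not code from the patch; `planStages` is a hypothetical helper, and `sessionState` is `protected[sql]`, so this access pattern applies to Spark-internal code in the `org.apache.spark.sql` package.

```
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.plans.logical.ReturnAnswer
import org.apache.spark.sql.execution.SparkPlan

// Sketch only: the removed SQLContext forwarders (sqlParser, optimizer,
// planner, prepareForExecution) are now reached through the per-session
// SessionState instead of SQLContext itself.
def planStages(sqlContext: SQLContext, sql: String): SparkPlan = {
  val parsed    = sqlContext.sessionState.sqlParser.parsePlan(sql)
  val optimized = sqlContext.sessionState.optimizer.execute(parsed)
  val physical  = sqlContext.sessionState.planner.plan(ReturnAnswer(optimized)).next()
  sqlContext.sessionState.prepareForExecution.execute(physical)
}
```

Planning rules such as EnsureRequirements, CollapseCodegenStages and ReuseExchange now take only the SQLConf (or SessionState) they actually read, which is why the test call sites below pass `sqlContext.sessionState.conf` instead of a whole SQLContext.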
Diffstat (limited to 'sql/core/src')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 3
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala | 2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 23
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala | 11
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala | 12
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 4
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala | 7
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala | 12
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala | 8
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala | 16
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala | 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 3
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 26
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala | 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 3
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala | 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala | 2
25 files changed, 89 insertions, 89 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 52b567ea25..76b8d71ac9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -394,7 +394,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*/
def table(tableName: String): DataFrame = {
Dataset.newDataFrame(sqlContext,
- sqlContext.catalog.lookupRelation(sqlContext.sqlParser.parseTableIdentifier(tableName)))
+ sqlContext.catalog.lookupRelation(
+ sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 3349b8421b..de87f4d7c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -242,7 +242,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
options = extraOptions.toMap,
partitionColumns = normalizedParCols.getOrElse(Nil))
- df.sqlContext.continuousQueryManager.startQuery(
+ df.sqlContext.sessionState.continuousQueryManager.startQuery(
extraOptions.getOrElse("queryName", StreamExecution.nextName), df, dataSource.createSink())
}
@@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def insertInto(tableName: String): Unit = {
- insertInto(df.sqlContext.sqlParser.parseTableIdentifier(tableName))
+ insertInto(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))
}
private def insertInto(tableIdent: TableIdentifier): Unit = {
@@ -354,7 +354,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def saveAsTable(tableName: String): Unit = {
- saveAsTable(df.sqlContext.sqlParser.parseTableIdentifier(tableName))
+ saveAsTable(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))
}
private def saveAsTable(tableIdent: TableIdentifier): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b5079cf276..ef239a1e2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -818,7 +818,7 @@ class Dataset[T] private[sql](
@scala.annotation.varargs
def selectExpr(exprs: String*): DataFrame = {
select(exprs.map { expr =>
- Column(sqlContext.sqlParser.parseExpression(expr))
+ Column(sqlContext.sessionState.sqlParser.parseExpression(expr))
}: _*)
}
@@ -919,7 +919,7 @@ class Dataset[T] private[sql](
* @since 1.3.0
*/
def filter(conditionExpr: String): Dataset[T] = {
- filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
+ filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr)))
}
/**
@@ -943,7 +943,7 @@ class Dataset[T] private[sql](
* @since 1.5.0
*/
def where(conditionExpr: String): Dataset[T] = {
- filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
+ filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr)))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
index deed45d273..d7cd84fd24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
* @since 1.3.0
*/
@Experimental
-class ExperimentalMethods protected[sql](sqlContext: SQLContext) {
+class ExperimentalMethods private[sql]() {
/**
* Allows extra strategies to be injected into the query planner at runtime. Note this API
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 36fe57f78b..0f5d1c8cab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -121,14 +121,7 @@ class SQLContext private[sql](
protected[sql] lazy val sessionState: SessionState = new SessionState(self)
protected[sql] def conf: SQLConf = sessionState.conf
protected[sql] def catalog: Catalog = sessionState.catalog
- protected[sql] def functionRegistry: FunctionRegistry = sessionState.functionRegistry
protected[sql] def analyzer: Analyzer = sessionState.analyzer
- protected[sql] def optimizer: Optimizer = sessionState.optimizer
- protected[sql] def sqlParser: ParserInterface = sessionState.sqlParser
- protected[sql] def planner: SparkPlanner = sessionState.planner
- protected[sql] def continuousQueryManager = sessionState.continuousQueryManager
- protected[sql] def prepareForExecution: RuleExecutor[SparkPlan] =
- sessionState.prepareForExecution
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@@ -197,7 +190,7 @@ class SQLContext private[sql](
*/
def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
- protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser.parsePlan(sql)
+ protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql)
protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql))
@@ -244,7 +237,7 @@ class SQLContext private[sql](
*/
@Experimental
@transient
- val experimental: ExperimentalMethods = new ExperimentalMethods(this)
+ def experimental: ExperimentalMethods = sessionState.experimentalMethods
/**
* :: Experimental ::
@@ -641,7 +634,7 @@ class SQLContext private[sql](
tableName: String,
source: String,
options: Map[String, String]): DataFrame = {
- val tableIdent = sqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableIdent,
@@ -687,7 +680,7 @@ class SQLContext private[sql](
source: String,
schema: StructType,
options: Map[String, String]): DataFrame = {
- val tableIdent = sqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableIdent,
@@ -706,7 +699,7 @@ class SQLContext private[sql](
* only during the lifetime of this instance of SQLContext.
*/
private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
- catalog.registerTable(sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
+ catalog.registerTable(sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
}
/**
@@ -800,7 +793,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def table(tableName: String): DataFrame = {
- table(sqlParser.parseTableIdentifier(tableName))
+ table(sessionState.sqlParser.parseTableIdentifier(tableName))
}
private def table(tableIdent: TableIdentifier): DataFrame = {
@@ -837,9 +830,7 @@ class SQLContext private[sql](
*
* @since 2.0.0
*/
- def streams: ContinuousQueryManager = {
- continuousQueryManager
- }
+ def streams: ContinuousQueryManager = sessionState.continuousQueryManager
/**
* Returns the names of tables in the current database as an array.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 9e60c1cd61..5b4254f741 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -45,16 +45,16 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
sqlContext.cacheManager.useCachedData(analyzed)
}
- lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(withCachedData)
+ lazy val optimizedPlan: LogicalPlan = sqlContext.sessionState.optimizer.execute(withCachedData)
lazy val sparkPlan: SparkPlan = {
SQLContext.setActive(sqlContext)
- sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next()
+ sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next()
}
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
- lazy val executedPlan: SparkPlan = sqlContext.prepareForExecution.execute(sparkPlan)
+ lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan)
/** Internal version of the RDD. Avoids copies and has no schema */
lazy val toRdd: RDD[InternalRow] = executedPlan.execute()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index edaf3b36aa..cbde777d98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.optimizer._
+import org.apache.spark.sql.ExperimentalMethods
+import org.apache.spark.sql.catalyst.optimizer.Optimizer
-class SparkOptimizer(val sqlContext: SQLContext)
- extends Optimizer {
- override def batches: Seq[Batch] = super.batches :+ Batch(
- "User Provided Optimizers", FixedPoint(100), sqlContext.experimental.extraOptimizations: _*)
+class SparkOptimizer(experimentalMethods: ExperimentalMethods) extends Optimizer {
+ override def batches: Seq[Batch] = super.batches :+ Batch(
+ "User Provided Optimizers", FixedPoint(100), experimentalMethods.extraOptimizations: _*)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 292d366e72..9da2c74c62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -21,14 +21,18 @@ import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
+import org.apache.spark.sql.internal.SQLConf
-class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies {
- val sparkContext: SparkContext = sqlContext.sparkContext
+class SparkPlanner(
+ val sparkContext: SparkContext,
+ val conf: SQLConf,
+ val experimentalMethods: ExperimentalMethods)
+ extends SparkStrategies {
- def numPartitions: Int = sqlContext.conf.numShufflePartitions
+ def numPartitions: Int = conf.numShufflePartitions
def strategies: Seq[Strategy] =
- sqlContext.experimental.extraStrategies ++ (
+ experimentalMethods.extraStrategies ++ (
FileSourceStrategy ::
DataSourceStrategy ::
DDLStrategy ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6352c48c76..113cf9ae2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -80,8 +80,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object CanBroadcast {
def unapply(plan: LogicalPlan): Option[LogicalPlan] = {
- if (sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
- plan.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) {
+ if (conf.autoBroadcastJoinThreshold > 0 &&
+ plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
Some(plan)
} else {
None
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 8fb4705581..81676d3ebb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -29,6 +28,7 @@ import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.execution.metric.LongSQLMetricValue
+import org.apache.spark.sql.internal.SQLConf
/**
* An interface for those physical operators that support codegen.
@@ -427,7 +427,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
/**
* Find the chained plans that support codegen, collapse them together as WholeStageCodegen.
*/
-private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {
+case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
private def supportCodegen(e: Expression): Boolean = e match {
case e: LeafExpression => true
@@ -472,7 +472,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
}
def apply(plan: SparkPlan): SparkPlan = {
- if (sqlContext.conf.wholeStageEnabled) {
+ if (conf.wholeStageEnabled) {
insertWholeStageCodegen(plan)
} else {
plan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 6e36a15a6d..e711797c1b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -358,13 +358,14 @@ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends Ru
case Some(p) =>
try {
val regex = java.util.regex.Pattern.compile(p)
- sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_))
+ sqlContext.sessionState.functionRegistry.listFunction()
+ .filter(regex.matcher(_).matches()).map(Row(_))
} catch {
// probably will failed in the regex that user provided, then returns empty row.
case _: Throwable => Seq.empty[Row]
}
case None =>
- sqlContext.functionRegistry.listFunction().map(Row(_))
+ sqlContext.sessionState.functionRegistry.listFunction().map(Row(_))
}
}
@@ -395,7 +396,7 @@ case class DescribeFunction(
}
override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.functionRegistry.lookupFunction(functionName) match {
+ sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match {
case Some(info) =>
val result =
Row(s"Function: ${info.getName}") ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 709a424636..4864db7f2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.execution.exchange
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.internal.SQLConf
/**
* Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
@@ -30,15 +30,15 @@ import org.apache.spark.sql.execution._
* each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the
* input partition ordering requirements are met.
*/
-private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
- private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions
+case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
+ private def defaultNumPreShufflePartitions: Int = conf.numShufflePartitions
- private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize
+ private def targetPostShuffleInputSize: Long = conf.targetPostShuffleInputSize
- private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled
+ private def adaptiveExecutionEnabled: Boolean = conf.adaptiveExecutionEnabled
private def minNumPostShufflePartitions: Option[Int] = {
- val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions
+ val minNumPostShufflePartitions = conf.minNumPostShufflePartitions
if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
index 12513e9106..9eaadea1b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -22,11 +22,11 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
/**
@@ -64,10 +64,10 @@ case class ReusedExchange(override val output: Seq[Attribute], child: Exchange)
* Find out duplicated exchanges in the spark plan, then use the same exchange for all the
* references.
*/
-private[sql] case class ReuseExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
+case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
- if (!sqlContext.conf.exchangeReuseEnabled) {
+ if (!conf.exchangeReuseEnabled) {
return plan
}
// Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index e6d7480b04..0d580703f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -17,12 +17,12 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.DataType
/**
@@ -62,12 +62,12 @@ case class ScalarSubquery(
/**
* Convert the subquery from logical plan into executed plan.
*/
-case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
+case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressions {
case subquery: expressions.ScalarSubquery =>
- val sparkPlan = sqlContext.planner.plan(ReturnAnswer(subquery.query)).next()
- val executedPlan = sqlContext.prepareForExecution.execute(sparkPlan)
+ val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next()
+ val executedPlan = sessionState.prepareForExecution.execute(sparkPlan)
ScalarSubquery(executedPlan, subquery.exprId)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 326c1e5a7c..dd4aa9e93a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1161,7 +1161,7 @@ object functions {
* @group normal_funcs
*/
def expr(expr: String): Column = {
- val parser = SQLContext.getActive().map(_.sqlParser).getOrElse(new CatalystQl())
+ val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl())
Column(parser.parseExpression(expr))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 98ada4d58a..e6be0ab3bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.internal
-import org.apache.spark.sql.{ContinuousQueryManager, SQLContext, UDFRegistration}
+import org.apache.spark.sql.{ContinuousQueryManager, ExperimentalMethods, SQLContext, UDFRegistration}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog, FunctionRegistry, SimpleCatalog}
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
@@ -40,6 +40,8 @@ private[sql] class SessionState(ctx: SQLContext) {
*/
lazy val conf = new SQLConf
+ lazy val experimentalMethods = new ExperimentalMethods
+
/**
* Internal catalog for managing table and database states.
*/
@@ -73,7 +75,7 @@ private[sql] class SessionState(ctx: SQLContext) {
/**
* Logical query plan optimizer.
*/
- lazy val optimizer: Optimizer = new SparkOptimizer(ctx)
+ lazy val optimizer: Optimizer = new SparkOptimizer(experimentalMethods)
/**
* Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
@@ -83,7 +85,7 @@ private[sql] class SessionState(ctx: SQLContext) {
/**
* Planner that converts optimized logical plans to physical plans.
*/
- lazy val planner: SparkPlanner = new SparkPlanner(ctx)
+ lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods)
/**
* Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
@@ -91,10 +93,10 @@ private[sql] class SessionState(ctx: SQLContext) {
*/
lazy val prepareForExecution = new RuleExecutor[SparkPlan] {
override val batches: Seq[Batch] = Seq(
- Batch("Subquery", Once, PlanSubqueries(ctx)),
- Batch("Add exchange", Once, EnsureRequirements(ctx)),
- Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)),
- Batch("Reuse duplicated exchanges", Once, ReuseExchange(ctx))
+ Batch("Subquery", Once, PlanSubqueries(SessionState.this)),
+ Batch("Add exchange", Once, EnsureRequirements(conf)),
+ Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)),
+ Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf))
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 2bd29ef19b..50647c2840 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -37,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
val x = testData2.as("x")
val y = testData2.as("y")
val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan
- val planned = sqlContext.planner.EquiJoinSelection(join)
+ val planned = sqlContext.sessionState.planner.EquiJoinSelection(join)
assert(planned.size === 1)
}
@@ -139,7 +139,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
val x = testData2.as("x")
val y = testData2.as("y")
val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan
- val planned = sqlContext.planner.EquiJoinSelection(join)
+ val planned = sqlContext.sessionState.planner.EquiJoinSelection(join)
assert(planned.size === 1)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index ec19d97d8c..2ad92b52c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -76,6 +76,6 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext{
test("Catalyst optimization passes are modifiable at runtime") {
val sqlContext = SQLContext.getOrCreate(sc)
sqlContext.experimental.extraOptimizations = Seq(DummyRule)
- assert(sqlContext.optimizer.batches.flatMap(_.rules).contains(DummyRule))
+ assert(sqlContext.sessionState.optimizer.batches.flatMap(_.rules).contains(DummyRule))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 98d0008489..836fb1ce85 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -54,7 +54,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("show functions") {
def getFunctions(pattern: String): Seq[Row] = {
val regex = java.util.regex.Pattern.compile(pattern)
- sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_))
+ sqlContext.sessionState.functionRegistry.listFunction()
+ .filter(regex.matcher(_).matches()).map(Row(_))
}
checkAnswer(sql("SHOW functions"), getFunctions(".*"))
Seq("^c.*", ".*e$", "log.*", ".*date.*").foreach { pattern =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index ab0a7ff628..88fbcda296 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -37,7 +37,7 @@ class PlannerSuite extends SharedSQLContext {
setupTestData()
private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
- val planner = sqlContext.planner
+ val planner = sqlContext.sessionState.planner
import planner._
val plannedOption = Aggregation(query).headOption
val planned =
@@ -294,7 +294,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildDistribution = Seq(distribution, distribution),
requiredChildOrdering = Seq(Seq.empty, Seq.empty)
)
- val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
fail(s"Exchange should have been added:\n$outputPlan")
@@ -314,7 +314,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildDistribution = Seq(distribution, distribution),
requiredChildOrdering = Seq(Seq.empty, Seq.empty)
)
- val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
}
@@ -332,7 +332,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildDistribution = Seq(distribution, distribution),
requiredChildOrdering = Seq(Seq.empty, Seq.empty)
)
- val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
fail(s"Exchange should have been added:\n$outputPlan")
@@ -352,7 +352,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildDistribution = Seq(distribution, distribution),
requiredChildOrdering = Seq(Seq.empty, Seq.empty)
)
- val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
fail(s"Exchange should not have been added:\n$outputPlan")
@@ -375,7 +375,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildDistribution = Seq(distribution, distribution),
requiredChildOrdering = Seq(outputOrdering, outputOrdering)
)
- val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
fail(s"No Exchanges should have been added:\n$outputPlan")
@@ -391,7 +391,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildOrdering = Seq(Seq(orderingB)),
requiredChildDistribution = Seq(UnspecifiedDistribution)
)
- val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case s: Sort => true }.isEmpty) {
fail(s"Sort should have been added:\n$outputPlan")
@@ -407,7 +407,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildOrdering = Seq(Seq(orderingA)),
requiredChildDistribution = Seq(UnspecifiedDistribution)
)
- val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case s: Sort => true }.nonEmpty) {
fail(s"No sorts should have been added:\n$outputPlan")
@@ -424,7 +424,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildOrdering = Seq(Seq(orderingA, orderingB)),
requiredChildDistribution = Seq(UnspecifiedDistribution)
)
- val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case s: Sort => true }.isEmpty) {
fail(s"Sort should have been added:\n$outputPlan")
@@ -443,7 +443,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildOrdering = Seq(Seq.empty)),
None)
- val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) {
fail(s"Topmost Exchange should have been eliminated:\n$outputPlan")
@@ -463,7 +463,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildOrdering = Seq(Seq.empty)),
None)
- val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+ val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) {
fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan")
@@ -491,7 +491,7 @@ class PlannerSuite extends SharedSQLContext {
shuffle,
shuffle)
- val outputPlan = ReuseExchange(sqlContext).apply(inputPlan)
+ val outputPlan = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan)
if (outputPlan.collect { case e: ReusedExchange => true }.size != 1) {
fail(s"Should re-use the shuffle:\n$outputPlan")
}
@@ -507,7 +507,7 @@ class PlannerSuite extends SharedSQLContext {
ShuffleExchange(finalPartitioning, inputPlan),
ShuffleExchange(finalPartitioning, inputPlan))
- val outputPlan2 = ReuseExchange(sqlContext).apply(inputPlan2)
+ val outputPlan2 = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan2)
if (outputPlan2.collect { case e: ReusedExchange => true }.size != 2) {
fail(s"Should re-use the two shuffles:\n$outputPlan2")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index aa928cfc80..ed0d3f56e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -233,7 +233,7 @@ object SparkPlanTest {
private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
- val resolvedPlan = sqlContext.prepareForExecution.execute(
+ val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute(
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index a256ee95a1..6d5b777733 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -63,7 +63,8 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
// Comparison at the end is for broadcast left semi join
val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
val df3 = df1.join(broadcast(df2), joinExpression, joinType)
- val plan = EnsureRequirements(sqlContext).apply(df3.queryExecution.sparkPlan)
+ val plan =
+ EnsureRequirements(sqlContext.sessionState.conf).apply(df3.queryExecution.sparkPlan)
assert(plan.collect { case p: T => p }.size === 1)
plan.executeCollect()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 7eb15249eb..eeb44404e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -98,7 +98,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
boundCondition,
leftPlan,
rightPlan)
- EnsureRequirements(sqlContext).apply(broadcastJoin)
+ EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin)
}
def makeSortMergeJoin(
@@ -109,7 +109,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
rightPlan: SparkPlan) = {
val sortMergeJoin =
joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan)
- EnsureRequirements(sqlContext).apply(sortMergeJoin)
+ EnsureRequirements(sqlContext.sessionState.conf).apply(sortMergeJoin)
}
test(s"$testName using BroadcastHashJoin (build=left)") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 0d1c29fe57..4525486430 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -98,7 +98,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- EnsureRequirements(sqlContext).apply(
+ EnsureRequirements(sqlContext.sessionState.conf).apply(
SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index bc341db557..d8c9564f1e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -76,7 +76,7 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- EnsureRequirements(left.sqlContext).apply(
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)