author     Reynold Xin <rxin@databricks.com>    2016-05-19 21:53:26 -0700
committer  Reynold Xin <rxin@databricks.com>    2016-05-19 21:53:26 -0700
commit     f2ee0ed4b7ecb2855cc4928a9613a07d45446f4e (patch)
tree       3c923b935bcf35219f158ed5a8ca34edfb7c9322
parent     17591d90e6873f30a042112f56a1686726ccbd60 (diff)
[SPARK-15075][SPARK-15345][SQL] Clean up SparkSession builder and propagate config options to existing sessions if specified
## What changes were proposed in this pull request?

Currently SparkSession.Builder uses SQLContext.getOrCreate. It should be the other way around, i.e. all the core logic goes in SparkSession, and SQLContext just calls that. This patch does that. It also makes sure config options specified in the builder are propagated to the existing (and of course the new) SparkSession.

## How was this patch tested?

Updated existing tests to reflect the change, and introduced a new SparkSessionBuilderSuite that should cover all the branches.

Author: Reynold Xin <rxin@databricks.com>

Closes #13200 from rxin/SPARK-15075.
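A minimal sketch of the behavior this patch describes, assuming a local master; the option key `spark.some.option` is illustrative, not part of the patch. A second call to the builder returns the already-created session and applies the newly specified options to it.

    import org.apache.spark.sql.SparkSession

    object BuilderPropagationSketch {
      def main(args: Array[String]): Unit = {
        // First call creates the global default session.
        val spark = SparkSession.builder()
          .master("local")
          .appName("builder-sketch")         // without an app name, a random UUID is used
          .config("spark.some.option", "a")  // hypothetical option key
          .getOrCreate()

        // Second call returns the existing session instead of creating a new one,
        // and the option specified here is propagated to that session.
        val same = SparkSession.builder()
          .config("spark.some.option", "b")
          .getOrCreate()

        assert(same eq spark)
        assert(spark.conf.get("spark.some.option") == "b")

        spark.stop()
      }
    }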
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java | 2
-rw-r--r--  python/pyspark/sql/context.py | 5
-rw-r--r--  python/pyspark/sql/session.py | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 124
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 155
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala | 19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala | 13
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala | 100
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala | 93
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala | 16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala | 14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala | 18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala | 4
-rw-r--r--  sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala | 2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala | 2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala | 42
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala | 2
43 files changed, 367 insertions, 357 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
index da623d1d15..7bda219243 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
@@ -56,7 +56,7 @@ public class JavaDefaultReadWriteSuite extends SharedSparkSession {
} catch (IOException e) {
// expected
}
- instance.write().context(spark.wrapped()).overwrite().save(outputPath);
+ instance.write().context(spark.sqlContext()).overwrite().save(outputPath);
MyParams newInstance = MyParams.load(outputPath);
Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
Assert.assertEquals("Params should be preserved.",
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index e8e60c6412..486733a390 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -34,7 +34,10 @@ __all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
class SQLContext(object):
- """Wrapper around :class:`SparkSession`, the main entry point to Spark SQL functionality.
+ """The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x.
+
+ As of Spark 2.0, this is replaced by :class:`SparkSession`. However, we are keeping the class
+ here for backward compatibility.
A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as
tables, execute SQL over tables, cache tables, and read parquet files.
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 257a239c8d..0e04b88265 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -120,6 +120,8 @@ class SparkSession(object):
def appName(self, name):
"""Sets a name for the application, which will be shown in the Spark web UI.
+ If no application name is set, a randomly generated name will be used.
+
:param name: an application name
"""
return self.config("spark.app.name", name)
@@ -133,8 +135,17 @@ class SparkSession(object):
@since(2.0)
def getOrCreate(self):
- """Gets an existing :class:`SparkSession` or, if there is no existing one, creates a new
- one based on the options set in this builder.
+ """Gets an existing :class:`SparkSession` or, if there is no existing one, creates a
+ new one based on the options set in this builder.
+
+ This method first checks whether there is a valid thread-local SparkSession,
+ and if yes, return that one. It then checks whether there is a valid global
+ default SparkSession, and if yes, return that one. If no valid global default
+ SparkSession exists, the method creates a new SparkSession and assigns the
+ newly created SparkSession as the global default.
+
+ In case an existing SparkSession is returned, the config options specified
+ in this builder will be applied to the existing SparkSession.
"""
with self._lock:
from pyspark.conf import SparkConf
@@ -175,7 +186,7 @@ class SparkSession(object):
if jsparkSession is None:
jsparkSession = self._jvm.SparkSession(self._jsc.sc())
self._jsparkSession = jsparkSession
- self._jwrapped = self._jsparkSession.wrapped()
+ self._jwrapped = self._jsparkSession.sqlContext()
self._wrapped = SQLContext(self._sc, self, self._jwrapped)
_monkey_patch_RDD(self)
install_exception_handler()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 02dd6547a4..78a167eef2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -213,7 +213,7 @@ class Dataset[T] private[sql](
private implicit def classTag = unresolvedTEncoder.clsTag
// sqlContext must be val because a stable identifier is expected when you import implicits
- @transient lazy val sqlContext: SQLContext = sparkSession.wrapped
+ @transient lazy val sqlContext: SQLContext = sparkSession.sqlContext
protected[sql] def resolve(colName: String): NamedExpression = {
queryExecution.analyzed.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a3e2b49556..14d12d30bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -19,25 +19,22 @@ package org.apache.spark.sql
import java.beans.BeanInfo
import java.util.Properties
-import java.util.concurrent.atomic.AtomicReference
import scala.collection.immutable
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.{SparkConf, SparkContext, SparkException}
+import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.ConfigEntry
import org.apache.spark.rdd.RDD
-import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.ShowTablesCommand
-import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
@@ -46,8 +43,8 @@ import org.apache.spark.sql.util.ExecutionListenerManager
/**
* The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x.
*
- * As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class here
- * for backward compatibility.
+ * As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class
+ * here for backward compatibility.
*
* @groupname basic Basic Operations
* @groupname ddl_ops Persistent Catalog DDL
@@ -76,42 +73,21 @@ class SQLContext private[sql](
this(sparkSession, true)
}
+ @deprecated("Use SparkSession.builder instead", "2.0.0")
def this(sc: SparkContext) = {
this(new SparkSession(sc))
}
+ @deprecated("Use SparkSession.builder instead", "2.0.0")
def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)
// TODO: move this logic into SparkSession
- // If spark.sql.allowMultipleContexts is true, we will throw an exception if a user
- // wants to create a new root SQLContext (a SQLContext that is not created by newSession).
- private val allowMultipleContexts =
- sparkContext.conf.getBoolean(
- SQLConf.ALLOW_MULTIPLE_CONTEXTS.key,
- SQLConf.ALLOW_MULTIPLE_CONTEXTS.defaultValue.get)
-
- // Assert no root SQLContext is running when allowMultipleContexts is false.
- {
- if (!allowMultipleContexts && isRootContext) {
- SQLContext.getInstantiatedContextOption() match {
- case Some(rootSQLContext) =>
- val errMsg = "Only one SQLContext/HiveContext may be running in this JVM. " +
- s"It is recommended to use SQLContext.getOrCreate to get the instantiated " +
- s"SQLContext/HiveContext. To ignore this error, " +
- s"set ${SQLConf.ALLOW_MULTIPLE_CONTEXTS.key} = true in SparkConf."
- throw new SparkException(errMsg)
- case None => // OK
- }
- }
- }
-
protected[sql] def sessionState: SessionState = sparkSession.sessionState
protected[sql] def sharedState: SharedState = sparkSession.sharedState
protected[sql] def conf: SQLConf = sessionState.conf
protected[sql] def runtimeConf: RuntimeConfig = sparkSession.conf
protected[sql] def cacheManager: CacheManager = sparkSession.cacheManager
- protected[sql] def listener: SQLListener = sparkSession.listener
protected[sql] def externalCatalog: ExternalCatalog = sparkSession.externalCatalog
def sparkContext: SparkContext = sparkSession.sparkContext
@@ -123,7 +99,7 @@ class SQLContext private[sql](
*
* @since 1.6.0
*/
- def newSession(): SQLContext = sparkSession.newSession().wrapped
+ def newSession(): SQLContext = sparkSession.newSession().sqlContext
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@@ -760,21 +736,6 @@ class SQLContext private[sql](
schema: StructType): DataFrame = {
sparkSession.applySchemaToPythonRDD(rdd, schema)
}
-
- // TODO: move this logic into SparkSession
-
- // Register a successfully instantiated context to the singleton. This should be at the end of
- // the class definition so that the singleton is updated only if there is no exception in the
- // construction of the instance.
- sparkContext.addSparkListener(new SparkListener {
- override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
- SQLContext.clearInstantiatedContext()
- SQLContext.clearSqlListener()
- }
- })
-
- sparkSession.setWrappedContext(self)
- SQLContext.setInstantiatedContext(self)
}
/**
@@ -788,19 +749,6 @@ class SQLContext private[sql](
object SQLContext {
/**
- * The active SQLContext for the current thread.
- */
- private val activeContext: InheritableThreadLocal[SQLContext] =
- new InheritableThreadLocal[SQLContext]
-
- /**
- * Reference to the created SQLContext.
- */
- @transient private val instantiatedContext = new AtomicReference[SQLContext]()
-
- @transient private val sqlListener = new AtomicReference[SQLListener]()
-
- /**
* Get the singleton SQLContext if it exists or create a new one using the given SparkContext.
*
* This function can be used to create a singleton SQLContext object that can be shared across
@@ -811,41 +759,9 @@ object SQLContext {
*
* @since 1.5.0
*/
+ @deprecated("Use SparkSession.builder instead", "2.0.0")
def getOrCreate(sparkContext: SparkContext): SQLContext = {
- val ctx = activeContext.get()
- if (ctx != null && !ctx.sparkContext.isStopped) {
- return ctx
- }
-
- synchronized {
- val ctx = instantiatedContext.get()
- if (ctx == null || ctx.sparkContext.isStopped) {
- new SQLContext(sparkContext)
- } else {
- ctx
- }
- }
- }
-
- private[sql] def clearInstantiatedContext(): Unit = {
- instantiatedContext.set(null)
- }
-
- private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = {
- synchronized {
- val ctx = instantiatedContext.get()
- if (ctx == null || ctx.sparkContext.isStopped) {
- instantiatedContext.set(sqlContext)
- }
- }
- }
-
- private[sql] def getInstantiatedContextOption(): Option[SQLContext] = {
- Option(instantiatedContext.get())
- }
-
- private[sql] def clearSqlListener(): Unit = {
- sqlListener.set(null)
+ SparkSession.builder().sparkContext(sparkContext).getOrCreate().sqlContext
}
/**
@@ -855,8 +771,9 @@ object SQLContext {
*
* @since 1.6.0
*/
+ @deprecated("Use SparkSession.setActiveSession instead", "2.0.0")
def setActive(sqlContext: SQLContext): Unit = {
- activeContext.set(sqlContext)
+ SparkSession.setActiveSession(sqlContext.sparkSession)
}
/**
@@ -865,12 +782,9 @@ object SQLContext {
*
* @since 1.6.0
*/
+ @deprecated("Use SparkSession.clearActiveSession instead", "2.0.0")
def clearActive(): Unit = {
- activeContext.remove()
- }
-
- private[sql] def getActive(): Option[SQLContext] = {
- Option(activeContext.get())
+ SparkSession.clearActiveSession()
}
/**
@@ -895,20 +809,6 @@ object SQLContext {
}
/**
- * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI.
- */
- private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = {
- if (sqlListener.get() == null) {
- val listener = new SQLListener(sc.conf)
- if (sqlListener.compareAndSet(null, listener)) {
- sc.addSparkListener(listener)
- sc.ui.foreach(new SQLTab(listener, _))
- }
- }
- sqlListener.get()
- }
-
- /**
* Extract `spark.sql.*` properties from the conf and return them as a [[Properties]].
*/
private[sql] def getSQLProperties(sparkConf: SparkConf): Properties = {
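The hunks above deprecate the SQLContext entry points and forward them to SparkSession. A rough migration sketch, with illustrative method names that are not part of the patch:

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.{SparkSession, SQLContext}

    // Spark 1.x entry point, now deprecated: look up (or construct) a SQLContext directly.
    def legacyEntryPoint(sc: SparkContext): SQLContext = SQLContext.getOrCreate(sc)

    // Spark 2.0 entry point: go through the builder; it reuses a running SparkContext if one exists.
    def newEntryPoint(): SparkSession =
      SparkSession.builder().master("local").appName("migration-sketch").getOrCreate()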
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 629243bd1a..f697769bdc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql
import java.beans.Introspector
+import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
@@ -30,6 +31,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION
import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
@@ -98,24 +100,10 @@ class SparkSession private(
}
/**
- * A wrapped version of this session in the form of a [[SQLContext]].
+ * A wrapped version of this session in the form of a [[SQLContext]], for backward compatibility.
*/
@transient
- private var _wrapped: SQLContext = _
-
- @transient
- private val _wrappedLock = new Object
-
- protected[sql] def wrapped: SQLContext = _wrappedLock.synchronized {
- if (_wrapped == null) {
- _wrapped = new SQLContext(self, isRootContext = false)
- }
- _wrapped
- }
-
- protected[sql] def setWrappedContext(sqlContext: SQLContext): Unit = _wrappedLock.synchronized {
- _wrapped = sqlContext
- }
+ private[sql] val sqlContext: SQLContext = new SQLContext(this)
protected[sql] def cacheManager: CacheManager = sharedState.cacheManager
protected[sql] def listener: SQLListener = sharedState.listener
@@ -238,7 +226,7 @@ class SparkSession private(
*/
@Experimental
def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
- SQLContext.setActive(wrapped)
+ SparkSession.setActiveSession(this)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType))
@@ -254,7 +242,7 @@ class SparkSession private(
*/
@Experimental
def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
- SQLContext.setActive(wrapped)
+ SparkSession.setActiveSession(this)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data))
@@ -573,7 +561,7 @@ class SparkSession private(
*/
@Experimental
object implicits extends SQLImplicits with Serializable {
- protected override def _sqlContext: SQLContext = wrapped
+ protected override def _sqlContext: SQLContext = SparkSession.this.sqlContext
}
// scalastyle:on
@@ -649,8 +637,16 @@ object SparkSession {
private[this] val options = new scala.collection.mutable.HashMap[String, String]
+ private[this] var userSuppliedContext: Option[SparkContext] = None
+
+ private[sql] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
+ userSuppliedContext = Option(sparkContext)
+ this
+ }
+
/**
* Sets a name for the application, which will be shown in the Spark web UI.
+ * If no application name is set, a randomly generated name will be used.
*
* @since 2.0.0
*/
@@ -735,29 +731,130 @@ object SparkSession {
}
/**
- * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new one
- * based on the options set in this builder.
+ * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new
+ * one based on the options set in this builder.
+ *
+ * This method first checks whether there is a valid thread-local SparkSession,
+ * and if yes, return that one. It then checks whether there is a valid global
+ * default SparkSession, and if yes, return that one. If no valid global default
+ * SparkSession exists, the method creates a new SparkSession and assigns the
+ * newly created SparkSession as the global default.
+ *
+ * In case an existing SparkSession is returned, the config options specified in
+ * this builder will be applied to the existing SparkSession.
*
* @since 2.0.0
*/
def getOrCreate(): SparkSession = synchronized {
- // Step 1. Create a SparkConf
- // Step 2. Get a SparkContext
- // Step 3. Get a SparkSession
- val sparkConf = new SparkConf()
- options.foreach { case (k, v) => sparkConf.set(k, v) }
- val sparkContext = SparkContext.getOrCreate(sparkConf)
-
- SQLContext.getOrCreate(sparkContext).sparkSession
+ // Get the session from current thread's active session.
+ var session = activeThreadSession.get()
+ if ((session ne null) && !session.sparkContext.isStopped) {
+ options.foreach { case (k, v) => session.conf.set(k, v) }
+ return session
+ }
+
+ // Global synchronization so we will only set the default session once.
+ SparkSession.synchronized {
+ // If the current thread does not have an active session, get it from the global session.
+ session = defaultSession.get()
+ if ((session ne null) && !session.sparkContext.isStopped) {
+ options.foreach { case (k, v) => session.conf.set(k, v) }
+ return session
+ }
+
+ // No active nor global default session. Create a new one.
+ val sparkContext = userSuppliedContext.getOrElse {
+ // set app name if not given
+ if (!options.contains("spark.app.name")) {
+ options += "spark.app.name" -> java.util.UUID.randomUUID().toString
+ }
+
+ val sparkConf = new SparkConf()
+ options.foreach { case (k, v) => sparkConf.set(k, v) }
+ SparkContext.getOrCreate(sparkConf)
+ }
+ session = new SparkSession(sparkContext)
+ options.foreach { case (k, v) => session.conf.set(k, v) }
+ defaultSession.set(session)
+
+ // Register a successfully instantiated context to the singleton. This should be at the
+ // end of the class definition so that the singleton is updated only if there is no
+ // exception in the construction of the instance.
+ sparkContext.addSparkListener(new SparkListener {
+ override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
+ defaultSession.set(null)
+ sqlListener.set(null)
+ }
+ })
+ }
+
+ return session
}
}
/**
* Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
+ *
* @since 2.0.0
*/
def builder(): Builder = new Builder
+ /**
+ * Changes the SparkSession that will be returned in this thread and its children when
+ * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
+ * a SparkSession with an isolated session, instead of the global (first created) context.
+ *
+ * @since 2.0.0
+ */
+ def setActiveSession(session: SparkSession): Unit = {
+ activeThreadSession.set(session)
+ }
+
+ /**
+ * Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will
+ * return the first created context instead of a thread-local override.
+ *
+ * @since 2.0.0
+ */
+ def clearActiveSession(): Unit = {
+ activeThreadSession.remove()
+ }
+
+ /**
+ * Sets the default SparkSession that is returned by the builder.
+ *
+ * @since 2.0.0
+ */
+ def setDefaultSession(session: SparkSession): Unit = {
+ defaultSession.set(session)
+ }
+
+ /**
+ * Clears the default SparkSession that is returned by the builder.
+ *
+ * @since 2.0.0
+ */
+ def clearDefaultSession(): Unit = {
+ defaultSession.set(null)
+ }
+
+ private[sql] def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get)
+
+ private[sql] def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)
+
+ /** A global SQL listener used for the SQL UI. */
+ private[sql] val sqlListener = new AtomicReference[SQLListener]()
+
+ ////////////////////////////////////////////////////////////////////////////////////////
+ // Private methods from now on
+ ////////////////////////////////////////////////////////////////////////////////////////
+
+ /** The active SparkSession for the current thread. */
+ private val activeThreadSession = new InheritableThreadLocal[SparkSession]
+
+ /** Reference to the root SparkSession. */
+ private val defaultSession = new AtomicReference[SparkSession]
+
private val HIVE_SHARED_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSharedState"
private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState"
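A small sketch of the lookup order documented in getOrCreate above, assuming a local master: a thread-local active session, when set, takes precedence over the global default session.

    import org.apache.spark.sql.SparkSession

    object ActiveSessionSketch {
      def main(args: Array[String]): Unit = {
        // The first getOrCreate establishes the global default session.
        val default = SparkSession.builder().master("local").getOrCreate()

        // newSession() shares the SparkContext but has an isolated SQL conf and temp views.
        val isolated = default.newSession()
        SparkSession.setActiveSession(isolated)

        // With a thread-local active session set, the builder returns it.
        assert(SparkSession.builder().getOrCreate() eq isolated)

        // After clearing the override, the builder falls back to the global default.
        SparkSession.clearActiveSession()
        assert(SparkSession.builder().getOrCreate() eq default)

        default.stop()
      }
    }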
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 85af4faf4d..d8911f88b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -157,7 +157,8 @@ private[sql] case class RowDataSourceScanExec(
val outputUnsafeRows = relation match {
case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
- !SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED)
+ !SparkSession.getActiveSession.get.sessionState.conf.getConf(
+ SQLConf.PARQUET_VECTORIZED_READER_ENABLED)
case _: HadoopFsRelation => true
case _ => false
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index cb3c46a98b..34187b9a1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -60,7 +60,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
}
lazy val analyzed: LogicalPlan = {
- SQLContext.setActive(sparkSession.wrapped)
+ SparkSession.setActiveSession(sparkSession)
sparkSession.sessionState.analyzer.execute(logical)
}
@@ -73,7 +73,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData)
lazy val sparkPlan: SparkPlan = {
- SQLContext.setActive(sparkSession.wrapped)
+ SparkSession.setActiveSession(sparkSession)
planner.plan(ReturnAnswer(optimizedPlan)).next()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index b94b84d77a..045ccc7bd6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -27,7 +27,7 @@ import org.apache.spark.{broadcast, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd.{RDD, RDDOperationScope}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -50,7 +50,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* populated by the query planning infrastructure.
*/
@transient
- protected[spark] final val sqlContext = SQLContext.getActive().orNull
+ final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull
protected def sparkContext = sqlContext.sparkContext
@@ -65,7 +65,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/** Overridden make copy also propagates sqlContext to copied plan. */
override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
- SQLContext.setActive(sqlContext)
+ SparkSession.setActiveSession(sqlContext.sparkSession)
super.makeCopy(newArgs)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index ccad9b3fd5..2e17b763a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -178,7 +178,7 @@ case class DataSource(
providingClass.newInstance() match {
case s: StreamSourceProvider =>
val (name, schema) = s.sourceSchema(
- sparkSession.wrapped, userSpecifiedSchema, className, options)
+ sparkSession.sqlContext, userSpecifiedSchema, className, options)
SourceInfo(name, schema)
case format: FileFormat =>
@@ -198,7 +198,8 @@ case class DataSource(
def createSource(metadataPath: String): Source = {
providingClass.newInstance() match {
case s: StreamSourceProvider =>
- s.createSource(sparkSession.wrapped, metadataPath, userSpecifiedSchema, className, options)
+ s.createSource(
+ sparkSession.sqlContext, metadataPath, userSpecifiedSchema, className, options)
case format: FileFormat =>
val path = new CaseInsensitiveMap(options).getOrElse("path", {
@@ -215,7 +216,7 @@ case class DataSource(
/** Returns a sink that can be used to continually write data. */
def createSink(): Sink = {
providingClass.newInstance() match {
- case s: StreamSinkProvider => s.createSink(sparkSession.wrapped, options, partitionColumns)
+ case s: StreamSinkProvider => s.createSink(sparkSession.sqlContext, options, partitionColumns)
case parquet: parquet.DefaultSource =>
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
@@ -265,9 +266,9 @@ case class DataSource(
val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
// TODO: Throw when too much is given.
case (dataSource: SchemaRelationProvider, Some(schema)) =>
- dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions, schema)
+ dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema)
case (dataSource: RelationProvider, None) =>
- dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions)
+ dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
case (_: SchemaRelationProvider, None) =>
throw new AnalysisException(s"A schema needs to be specified when using $className.")
case (_: RelationProvider, Some(_)) =>
@@ -383,7 +384,7 @@ case class DataSource(
providingClass.newInstance() match {
case dataSource: CreatableRelationProvider =>
- dataSource.createRelation(sparkSession.wrapped, mode, options, data)
+ dataSource.createRelation(sparkSession.sqlContext, mode, options, data)
case format: FileFormat =>
// Don't glob path for the write path. The contracts here are:
// 1. Only one output path can be specified on the write path;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
index 8d332df029..88125a2b4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
@@ -142,7 +142,7 @@ case class HadoopFsRelation(
fileFormat: FileFormat,
options: Map[String, String]) extends BaseRelation with FileRelation {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
val schema: StructType = {
val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index bcf70fdc4a..233b7891d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -92,7 +92,7 @@ private[sql] case class JDBCRelation(
with PrunedFilteredScan
with InsertableRelation {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
override val needConversion: Boolean = false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index df6304d85f..7d09bdcebd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -173,7 +173,7 @@ class StreamExecution(
startLatch.countDown()
// While active, repeatedly attempt to run batches.
- SQLContext.setActive(sparkSession.wrapped)
+ SparkSession.setActiveSession(sparkSession)
triggerExecutor.execute(() => {
if (isActive) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 65bc043076..0b490fe71c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1168,7 +1168,7 @@ object functions {
* @group normal_funcs
*/
def expr(expr: String): Column = {
- val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse {
+ val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse {
new SparkSqlParser(new SQLConf)
}
Column(parser.parseExpression(expr))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5d18689801..35d67ca2d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -70,16 +70,6 @@ object SQLConf {
.intConf
.createWithDefault(10)
- val ALLOW_MULTIPLE_CONTEXTS = SQLConfigBuilder("spark.sql.allowMultipleContexts")
- .doc("When set to true, creating multiple SQLContexts/HiveContexts is allowed. " +
- "When set to false, only one SQLContext/HiveContext is allowed to be created " +
- "through the constructor (new SQLContexts/HiveContexts created through newSession " +
- "method is allowed). Please note that this conf needs to be set in Spark Conf. Once " +
- "a SQLContext/HiveContext has been created, changing the value of this conf will not " +
- "have effect.")
- .booleanConf
- .createWithDefault(true)
-
val COMPRESS_CACHED = SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.compressed")
.internal()
.doc("When set to true Spark SQL will automatically select a compression codec for each " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index eaf993aaed..9f6137d6e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -18,10 +18,10 @@
package org.apache.spark.sql.internal
import org.apache.spark.SparkContext
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog}
import org.apache.spark.sql.execution.CacheManager
-import org.apache.spark.sql.execution.ui.SQLListener
+import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.util.MutableURLClassLoader
@@ -38,7 +38,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) {
/**
* A listener for SQL-specific [[org.apache.spark.scheduler.SparkListenerEvent]]s.
*/
- val listener: SQLListener = SQLContext.createListenerAndUI(sparkContext)
+ val listener: SQLListener = createListenerAndUI(sparkContext)
/**
* A catalog that interacts with external systems.
@@ -51,6 +51,19 @@ private[sql] class SharedState(val sparkContext: SparkContext) {
val jarClassLoader = new NonClosableMutableURLClassLoader(
org.apache.spark.util.Utils.getContextOrSparkClassLoader)
+ /**
+ * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI.
+ */
+ private def createListenerAndUI(sc: SparkContext): SQLListener = {
+ if (SparkSession.sqlListener.get() == null) {
+ val listener = new SQLListener(sc.conf)
+ if (SparkSession.sqlListener.compareAndSet(null, listener)) {
+ sc.addSparkListener(listener)
+ sc.ui.foreach(new SQLTab(listener, _))
+ }
+ }
+ SparkSession.sqlListener.get()
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index 65fe271b69..b447006761 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -39,7 +39,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
test("get all tables") {
checkAnswer(
- spark.wrapped.tables().filter("tableName = 'listtablessuitetable'"),
+ spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
checkAnswer(
@@ -48,12 +48,12 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
spark.sessionState.catalog.dropTable(
TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
- assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
+ assert(spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
}
test("getting all tables with a database name has no impact on returned table names") {
checkAnswer(
- spark.wrapped.tables("default").filter("tableName = 'listtablessuitetable'"),
+ spark.sqlContext.tables("default").filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
checkAnswer(
@@ -62,7 +62,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
spark.sessionState.catalog.dropTable(
TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
- assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
+ assert(spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
}
test("query the returned DataFrame of tables") {
@@ -70,7 +70,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
StructField("tableName", StringType, false) ::
StructField("isTemporary", BooleanType, false) :: Nil)
- Seq(spark.wrapped.tables(), sql("SHOW TABLes")).foreach {
+ Seq(spark.sqlContext.tables(), sql("SHOW TABLes")).foreach {
case tableDF =>
assert(expectedSchema === tableDF.schema)
@@ -81,7 +81,8 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
Row(true, "listtablessuitetable")
)
checkAnswer(
- spark.wrapped.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
+ spark.sqlContext.tables()
+ .filter("tableName = 'tables'").select("tableName", "isTemporary"),
Row("tables", true))
spark.catalog.dropTempView("tables")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala
deleted file mode 100644
index 0b5a92c256..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql
-
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark._
-import org.apache.spark.sql.internal.SQLConf
-
-class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll {
-
- private var originalActiveSQLContext: Option[SQLContext] = _
- private var originalInstantiatedSQLContext: Option[SQLContext] = _
- private var sparkConf: SparkConf = _
-
- override protected def beforeAll(): Unit = {
- originalActiveSQLContext = SQLContext.getActive()
- originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption()
-
- SQLContext.clearActive()
- SQLContext.clearInstantiatedContext()
- sparkConf =
- new SparkConf(false)
- .setMaster("local[*]")
- .setAppName("test")
- .set("spark.ui.enabled", "false")
- .set("spark.driver.allowMultipleContexts", "true")
- }
-
- override protected def afterAll(): Unit = {
- // Set these states back.
- originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx))
- originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx))
- }
-
- def testNewSession(rootSQLContext: SQLContext): Unit = {
- // Make sure we can successfully create new Session.
- rootSQLContext.newSession()
-
- // Reset the state. It is always safe to clear the active context.
- SQLContext.clearActive()
- }
-
- def testCreatingNewSQLContext(allowsMultipleContexts: Boolean): Unit = {
- val conf =
- sparkConf
- .clone
- .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowsMultipleContexts.toString)
- val sparkContext = new SparkContext(conf)
-
- try {
- if (allowsMultipleContexts) {
- new SQLContext(sparkContext)
- SQLContext.clearActive()
- } else {
- // If allowsMultipleContexts is false, make sure we can get the error.
- val message = intercept[SparkException] {
- new SQLContext(sparkContext)
- }.getMessage
- assert(message.contains("Only one SQLContext/HiveContext may be running"))
- }
- } finally {
- sparkContext.stop()
- }
- }
-
- test("test the flag to disallow creating multiple root SQLContext") {
- Seq(false, true).foreach { allowMultipleSQLContexts =>
- val conf =
- sparkConf
- .clone
- .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowMultipleSQLContexts.toString)
- val sc = new SparkContext(conf)
- try {
- val rootSQLContext = new SQLContext(sc)
- testNewSession(rootSQLContext)
- testNewSession(rootSQLContext)
- testCreatingNewSQLContext(allowMultipleSQLContexts)
- } finally {
- sc.stop()
- SQLContext.clearInstantiatedContext()
- }
- }
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index 38d7b6e25b..c9594a7e9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -40,7 +40,7 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
val newSession = sqlContext.newSession()
assert(SQLContext.getOrCreate(sc).eq(sqlContext),
"SQLContext.getOrCreate after explicitly created SQLContext did not return the context")
- SQLContext.setActive(newSession)
+ SparkSession.setActiveSession(newSession.sparkSession)
assert(SQLContext.getOrCreate(sc).eq(newSession),
"SQLContext.getOrCreate after explicitly setActive() did not return the active context")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 743a27aa7a..460e34a5ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1042,7 +1042,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SET commands semantics using sql()") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
val testKey = "test.key.0"
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
@@ -1083,17 +1083,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
sql(s"SET $nonexistentKey"),
Row(nonexistentKey, "<undefined>")
)
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
}
test("SET commands with illegal or inappropriate argument") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
// Set negative mapred.reduce.tasks for automatically determining
// the number of reducers is not supported
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1"))
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01"))
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2"))
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
}
test("apply schema") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
index b489b74fec..cd6b2647e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
@@ -25,6 +25,6 @@ class SerializationSuite extends SparkFunSuite with SharedSQLContext {
test("[SPARK-5235] SQLContext should be serializable") {
val spark = SparkSession.builder.getOrCreate()
- new JavaSerializer(new SparkConf()).newInstance().serialize(spark.wrapped)
+ new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sqlContext)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
new file mode 100644
index 0000000000..ec6a2b3575
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.{SparkContext, SparkFunSuite}
+
+/**
+ * Test cases for the builder pattern of [[SparkSession]].
+ */
+class SparkSessionBuilderSuite extends SparkFunSuite {
+
+ private var initialSession: SparkSession = _
+
+ private lazy val sparkContext: SparkContext = {
+ initialSession = SparkSession.builder()
+ .master("local")
+ .config("spark.ui.enabled", value = false)
+ .config("some-config", "v2")
+ .getOrCreate()
+ initialSession.sparkContext
+ }
+
+ test("create with config options and propagate them to SparkContext and SparkSession") {
+ // Creating a new session with config - this works by just calling the lazy val
+ sparkContext
+ assert(initialSession.sparkContext.conf.get("some-config") == "v2")
+ assert(initialSession.conf.get("some-config") == "v2")
+ SparkSession.clearDefaultSession()
+ }
+
+ test("use global default session") {
+ val session = SparkSession.builder().getOrCreate()
+ assert(SparkSession.builder().getOrCreate() == session)
+ SparkSession.clearDefaultSession()
+ }
+
+ test("config options are propagated to existing SparkSession") {
+ val session1 = SparkSession.builder().config("spark-config1", "a").getOrCreate()
+ assert(session1.conf.get("spark-config1") == "a")
+ val session2 = SparkSession.builder().config("spark-config1", "b").getOrCreate()
+ assert(session1 == session2)
+ assert(session1.conf.get("spark-config1") == "b")
+ SparkSession.clearDefaultSession()
+ }
+
+ test("use session from active thread session and propagate config options") {
+ val defaultSession = SparkSession.builder().getOrCreate()
+ val activeSession = defaultSession.newSession()
+ SparkSession.setActiveSession(activeSession)
+ val session = SparkSession.builder().config("spark-config2", "a").getOrCreate()
+
+ assert(activeSession != defaultSession)
+ assert(session == activeSession)
+ assert(session.conf.get("spark-config2") == "a")
+ SparkSession.clearActiveSession()
+
+ assert(SparkSession.builder().getOrCreate() == defaultSession)
+ SparkSession.clearDefaultSession()
+ }
+
+ test("create a new session if the default session has been stopped") {
+ val defaultSession = SparkSession.builder().getOrCreate()
+ SparkSession.setDefaultSession(defaultSession)
+ defaultSession.stop()
+ val newSession = SparkSession.builder().master("local").getOrCreate()
+ assert(newSession != defaultSession)
+ newSession.stop()
+ }
+
+ test("create a new session if the active thread session has been stopped") {
+ val activeSession = SparkSession.builder().master("local").getOrCreate()
+ SparkSession.setActiveSession(activeSession)
+ activeSession.stop()
+ val newSession = SparkSession.builder().master("local").getOrCreate()
+ assert(newSession != activeSession)
+ newSession.stop()
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
index 9523f6f9f5..4de3cf605c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
@@ -26,9 +26,9 @@ class StatisticsSuite extends QueryTest with SharedSQLContext {
val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
assert(df.queryExecution.analyzed.statistics.sizeInBytes >
- spark.wrapped.conf.autoBroadcastJoinThreshold)
+ spark.sessionState.conf.autoBroadcastJoinThreshold)
assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes >
- spark.wrapped.conf.autoBroadcastJoinThreshold)
+ spark.sessionState.conf.autoBroadcastJoinThreshold)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 70a00a43f7..2f45db3925 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -27,21 +27,21 @@ import org.apache.spark.sql.internal.SQLConf
class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
- private var originalActiveSQLContext: Option[SQLContext] = _
- private var originalInstantiatedSQLContext: Option[SQLContext] = _
+ private var originalActiveSQLContext: Option[SparkSession] = _
+ private var originalInstantiatedSQLContext: Option[SparkSession] = _
override protected def beforeAll(): Unit = {
- originalActiveSQLContext = SQLContext.getActive()
- originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption()
+ originalActiveSQLContext = SparkSession.getActiveSession
+ originalInstantiatedSQLContext = SparkSession.getDefaultSession
- SQLContext.clearActive()
- SQLContext.clearInstantiatedContext()
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
}
override protected def afterAll(): Unit = {
// Set these states back.
- originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx))
- originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx))
+ originalActiveSQLContext.foreach(ctx => SparkSession.setActiveSession(ctx))
+ originalInstantiatedSQLContext.foreach(ctx => SparkSession.setDefaultSession(ctx))
}
private def checkEstimation(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 2a5295d0d2..8243470b19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -155,7 +155,7 @@ class PlannerSuite extends SharedSQLContext {
val path = file.getCanonicalPath
testData.write.parquet(path)
val df = spark.read.parquet(path)
- spark.wrapped.registerDataFrameAsTable(df, "testPushed")
+ spark.sqlContext.registerDataFrameAsTable(df, "testPushed")
withTempTable("testPushed") {
val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index d7eae21f9f..9fe0e9646e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -91,7 +91,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
SparkPlanTest
- .checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.wrapped) match {
+ .checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -115,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
SparkPlanTest.checkAnswer(
- input, planFunction, expectedPlanFunction, sortAnswers, spark.wrapped) match {
+ input, planFunction, expectedPlanFunction, sortAnswers, spark.sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index b5fc51603e..1753b84ba6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -90,7 +90,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
(data: Seq[T], tableName: String, testVectorized: Boolean = true)
(f: => Unit): Unit = {
withParquetDataFrame(data, testVectorized) { df =>
- spark.wrapped.registerDataFrameAsTable(df, tableName)
+ spark.sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 4fa1754253..bd197be655 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -60,13 +60,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
val opId = 0
val rdd1 =
makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
- spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(
+ spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(
increment)
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
// Generate next version of stores
val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
- spark.wrapped, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
+ spark.sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
// Make sure the previous RDD still has the same data.
@@ -82,7 +82,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
spark: SparkSession,
seq: Seq[String],
storeVersion: Int): RDD[(String, Int)] = {
- implicit val sqlContext = spark.wrapped
+ implicit val sqlContext = spark.sqlContext
makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore(
sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment)
}
@@ -102,7 +102,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
test("usage with iterators - only gets and only puts") {
withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
- implicit val sqlContext = spark.wrapped
+ implicit val sqlContext = spark.sqlContext
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val opId = 0
@@ -131,7 +131,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
}
val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
- spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
+ spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
@@ -150,7 +150,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
- implicit val sqlContext = spark.wrapped
+ implicit val sqlContext = spark.sqlContext
val coordinatorRef = sqlContext.streams.stateStoreCoordinator
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1")
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2")
@@ -183,7 +183,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
SparkSession.builder
.config(sparkConf.setMaster("local-cluster[2, 1, 1024]"))
.getOrCreate()) { spark =>
- implicit val sqlContext = spark.wrapped
+ implicit val sqlContext = spark.sqlContext
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val opId = 0
val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
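
The StateStoreRDDSuite hunks above repeatedly wrap a test body in withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()). A hypothetical helper mirroring that call shape, under the assumption that the body only needs the session and the fixture should always stop it:

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

// Hypothetical helper mirroring the withSparkSession(...) calls above: build a
// session from a SparkConf, lend it to the test body, and always stop it afterwards.
def withSparkSessionSketch[T](conf: SparkConf)(body: SparkSession => T): T = {
  val spark = SparkSession.builder.config(conf).getOrCreate()
  try body(spark) finally spark.stop()
}
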
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 1c467137ba..2374ffaaa5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -24,7 +24,7 @@ import org.mockito.Mockito.mock
import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -400,8 +400,8 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite {
.set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly
val sc = new SparkContext(conf)
try {
- SQLContext.clearSqlListener()
- val spark = new SQLContext(sc)
+ SparkSession.sqlListener.set(null)
+ val spark = new SparkSession(sc)
import spark.implicits._
// Run 100 successful executions and 100 failed executions.
// Each execution only has one job and one stage.
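
The leak test now builds a SparkSession directly on an existing SparkContext and resets the process-wide SQL listener first. A condensed sketch of that setup; both SparkSession.sqlListener and the SparkSession(SparkContext) constructor are package-private, so this assumes code living in org.apache.spark.sql, as the suite does:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession

val conf = new SparkConf().setMaster("local").setAppName("listener-sketch")
val sc = new SparkContext(conf)
try {
  SparkSession.sqlListener.set(null)   // clear the JVM-wide SQL UI listener between tests
  val spark = new SparkSession(sc)     // constructor visible to org.apache.spark.sql tests
  import spark.implicits._
  Seq(1, 2, 3).toDF("i").count()       // exercise the session so the listener records an execution
} finally {
  sc.stop()
}
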
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index 81bc973be7..0296229100 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -35,7 +35,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
// Set a conf first.
spark.conf.set(testKey, testVal)
// Clear the conf.
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
// After clear, only overrideConfs used by unit test should be in the SQLConf.
assert(spark.conf.getAll === TestSQLContext.overrideConfs)
@@ -50,11 +50,11 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
assert(spark.conf.get(testKey, testVal + "_") === testVal)
assert(spark.conf.getAll.contains(testKey))
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
}
test("parse SQL set commands") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
sql(s"set $testKey=$testVal")
assert(spark.conf.get(testKey, testVal + "_") === testVal)
assert(spark.conf.get(testKey, testVal + "_") === testVal)
@@ -72,11 +72,11 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
sql(s"set $key=")
assert(spark.conf.get(key, "0") === "")
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
}
test("set command for display") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
checkAnswer(
sql("SET").where("key = 'spark.sql.groupByOrdinal'").select("key", "value"),
Nil)
@@ -97,7 +97,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
}
test("deprecated property") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
val original = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS)
try{
sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
@@ -108,7 +108,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
}
test("invalid conf value") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
val e = intercept[IllegalArgumentException] {
sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
}
@@ -116,7 +116,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
}
test("Test SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE's method") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100")
assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 100)
@@ -144,7 +144,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g")
}
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
}
test("SparkSession can access configs set in SparkConf") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
index 612cfc7ec7..a34f70ed65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
@@ -41,7 +41,7 @@ case class SimpleDDLScan(
table: String)(@transient val sparkSession: SparkSession)
extends BaseRelation with TableScan {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType =
StructType(Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 51d04f2f4e..f969660ddd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -40,7 +40,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S
extends BaseRelation
with PrunedFilteredScan {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType =
StructType(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index cd0256db43..9cdf7dea76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -37,7 +37,7 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sparkSession: Spa
extends BaseRelation
with PrunedScan {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType =
StructType(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 34b8726a92..cddf4a1884 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -38,7 +38,7 @@ class SimpleScanSource extends RelationProvider {
case class SimpleScan(from: Int, to: Int)(@transient val sparkSession: SparkSession)
extends BaseRelation with TableScan {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType =
StructType(StructField("i", IntegerType, nullable = false) :: Nil)
@@ -70,7 +70,7 @@ case class AllDataTypesScan(
extends BaseRelation
with TableScan {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType = userSpecifiedSchema
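
The data-source relations in the four suites above all carry a SparkSession and satisfy the BaseRelation.sqlContext contract through the session's wrapper. A minimal self-contained relation in the same shape (the class name is hypothetical):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Hypothetical single-column relation mirroring the test relations above.
case class SingleColumnScan(from: Int, to: Int)(@transient val sparkSession: SparkSession)
  extends BaseRelation with TableScan {

  // The relation exposes the session's wrapped SQLContext, as the patch does.
  override def sqlContext: SQLContext = sparkSession.sqlContext

  override def schema: StructType =
    StructType(StructField("i", IntegerType, nullable = false) :: Nil)

  override def buildScan(): RDD[Row] =
    sparkSession.sparkContext.parallelize(from to to).map(i => Row(i))
}
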
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index ff53505549..e6c0ce95e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -355,14 +355,14 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
q.stop()
verify(LastOptions.mockStreamSourceProvider).createSource(
- spark.wrapped,
+ spark.sqlContext,
checkpointLocation + "/sources/0",
None,
"org.apache.spark.sql.streaming.test",
Map.empty)
verify(LastOptions.mockStreamSourceProvider).createSource(
- spark.wrapped,
+ spark.sqlContext,
checkpointLocation + "/sources/1",
None,
"org.apache.spark.sql.streaming.test",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 421f6bca7f..0cfe260e52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -30,7 +30,7 @@ private[sql] trait SQLTestData { self =>
// Helper object to import SQL implicits without a concrete SQLContext
private object internalImplicits extends SQLImplicits {
- protected override def _sqlContext: SQLContext = self.spark.wrapped
+ protected override def _sqlContext: SQLContext = self.spark.sqlContext
}
import internalImplicits._
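
The implicits wiring above binds an SQLImplicits object to the suite's session by routing _sqlContext through the session's wrapper. A sketch of the same wiring with a hypothetical object name, resolving the session lazily through the builder:

import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits}

// Hypothetical implicits holder in the shape of internalImplicits/testImplicits above.
object sketchImplicits extends SQLImplicits {
  private def spark: SparkSession = SparkSession.builder.getOrCreate()
  protected override def _sqlContext: SQLContext = spark.sqlContext
}
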
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 51538eca64..853dd0ff3f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -66,7 +66,7 @@ private[sql] trait SQLTestUtils
* but the implicits import is needed in the constructor.
*/
protected object testImplicits extends SQLImplicits {
- protected override def _sqlContext: SQLContext = self.spark.wrapped
+ protected override def _sqlContext: SQLContext = self.spark.sqlContext
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
index 620bfa995a..79c37faa4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
@@ -44,13 +44,13 @@ trait SharedSQLContext extends SQLTestUtils {
/**
* The [[TestSQLContext]] to use for all tests in this suite.
*/
- protected implicit def sqlContext: SQLContext = _spark.wrapped
+ protected implicit def sqlContext: SQLContext = _spark.sqlContext
/**
* Initialize the [[TestSparkSession]].
*/
protected override def beforeAll(): Unit = {
- SQLContext.clearSqlListener()
+ SparkSession.sqlListener.set(null)
if (_spark == null) {
_spark = new TestSparkSession(sparkConf)
}
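
The shared-fixture accessor above keeps one session per suite and exposes its wrapped SQLContext implicitly for helpers that still expect one. A sketch of just that accessor, with a hypothetical trait name:

import org.apache.spark.sql.{SparkSession, SQLContext}

// Hypothetical trait mirroring the SharedSQLContext accessor above.
trait SharedSessionSketch {
  protected def spark: SparkSession
  protected implicit def sqlContext: SQLContext = spark.sqlContext
}
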
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
index 8de223f444..638911599a 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
@@ -56,7 +56,7 @@ private[hive] object SparkSQLEnv extends Logging {
val sparkSession = SparkSession.builder.config(sparkConf).enableHiveSupport().getOrCreate()
sparkContext = sparkSession.sparkContext
- sqlContext = sparkSession.wrapped
+ sqlContext = sparkSession.sqlContext
val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState]
sessionState.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8"))
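
The thrift-server bootstrap above builds one Hive-enabled session and then hands legacy consumers its SparkContext and wrapped SQLContext. A condensed sketch of that bootstrap, assuming Hive classes are on the classpath (app name and master are illustrative):

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

val sparkConf = new SparkConf().setAppName("SparkSQLEnv-sketch").setMaster("local")
val sparkSession = SparkSession.builder.config(sparkConf).enableHiveSupport().getOrCreate()
val sparkContext = sparkSession.sparkContext
val sqlContext = sparkSession.sqlContext   // replaces the removed sparkSession.wrapped
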
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
index d2cb62c617..7c74a0308d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
@@ -30,7 +30,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
override protected def beforeEach(): Unit = {
super.beforeEach()
- if (spark.wrapped.tableNames().contains("src")) {
+ if (spark.sqlContext.tableNames().contains("src")) {
spark.catalog.dropTempView("src")
}
Seq((1, "")).toDF("key", "value").createOrReplaceTempView("src")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
index 6c9ce208db..622b043581 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
@@ -36,11 +36,11 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
- assert(spark.wrapped.tableNames().contains("t"))
+ assert(spark.sqlContext.tableNames().contains("t"))
checkAnswer(spark.table("t"), df)
}
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df)
checkTablePath(db, "t")
@@ -50,7 +50,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
test(s"saveAsTable() to non-default database - without USE - Overwrite") {
withTempDatabase { db =>
df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df)
checkTablePath(db, "t")
@@ -65,7 +65,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
spark.catalog.createExternalTable("t", path, "parquet")
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table("t"), df)
sql(
@@ -76,7 +76,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
| path '$path'
|)
""".stripMargin)
- assert(spark.wrapped.tableNames(db).contains("t1"))
+ assert(spark.sqlContext.tableNames(db).contains("t1"))
checkAnswer(spark.table("t1"), df)
}
}
@@ -90,7 +90,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
spark.catalog.createExternalTable(s"$db.t", path, "parquet")
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df)
sql(
@@ -101,7 +101,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
| path '$path'
|)
""".stripMargin)
- assert(spark.wrapped.tableNames(db).contains("t1"))
+ assert(spark.sqlContext.tableNames(db).contains("t1"))
checkAnswer(spark.table(s"$db.t1"), df)
}
}
@@ -112,11 +112,11 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
df.write.mode(SaveMode.Append).saveAsTable("t")
- assert(spark.wrapped.tableNames().contains("t"))
+ assert(spark.sqlContext.tableNames().contains("t"))
checkAnswer(spark.table("t"), df.union(df))
}
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df.union(df))
checkTablePath(db, "t")
@@ -127,7 +127,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
df.write.mode(SaveMode.Append).saveAsTable(s"$db.t")
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df.union(df))
checkTablePath(db, "t")
@@ -138,7 +138,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
- assert(spark.wrapped.tableNames().contains("t"))
+ assert(spark.sqlContext.tableNames().contains("t"))
df.write.insertInto(s"$db.t")
checkAnswer(spark.table(s"$db.t"), df.union(df))
@@ -150,10 +150,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
- assert(spark.wrapped.tableNames().contains("t"))
+ assert(spark.sqlContext.tableNames().contains("t"))
}
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
df.write.insertInto(s"$db.t")
checkAnswer(spark.table(s"$db.t"), df.union(df))
@@ -175,21 +175,21 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
sql(s"CREATE TABLE t (key INT)")
- assert(spark.wrapped.tableNames().contains("t"))
- assert(!spark.wrapped.tableNames("default").contains("t"))
+ assert(spark.sqlContext.tableNames().contains("t"))
+ assert(!spark.sqlContext.tableNames("default").contains("t"))
}
- assert(!spark.wrapped.tableNames().contains("t"))
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(!spark.sqlContext.tableNames().contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
activateDatabase(db) {
sql(s"DROP TABLE t")
- assert(!spark.wrapped.tableNames().contains("t"))
- assert(!spark.wrapped.tableNames("default").contains("t"))
+ assert(!spark.sqlContext.tableNames().contains("t"))
+ assert(!spark.sqlContext.tableNames("default").contains("t"))
}
- assert(!spark.wrapped.tableNames().contains("t"))
- assert(!spark.wrapped.tableNames(db).contains("t"))
+ assert(!spark.sqlContext.tableNames().contains("t"))
+ assert(!spark.sqlContext.tableNames(db).contains("t"))
}
}
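
The catalog checks throughout this suite list tables via the wrapped SQLContext. A short sketch of that check, assuming a Hive-enabled `spark` and a database name bound to db: String:

spark.sql(s"CREATE TABLE $db.t (key INT)")
assert(spark.sqlContext.tableNames(db).contains("t"))          // visible in the target database
assert(!spark.sqlContext.tableNames("default").contains("t"))  // not leaked into default
spark.sql(s"DROP TABLE $db.t")
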
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 81f3ea8a6e..8a31a49d97 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1417,7 +1417,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
""".stripMargin)
checkAnswer(
- spark.wrapped.tables().select('isTemporary).filter('tableName === "t2"),
+ spark.sqlContext.tables().select('isTemporary).filter('tableName === "t2"),
Row(true)
)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
index aba60da33f..bb351e20c5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
@@ -61,7 +61,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton {
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withOrcDataFrame(data) { df =>
- spark.wrapped.registerDataFrameAsTable(df, tableName)
+ spark.sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}
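
ParquetTest and OrcTest both register a DataFrame under a temporary name, run the test body, and clean up. A hypothetical helper condensing that pattern; registerDataFrameAsTable is package-private, so this assumes sql-package test code like the traits above:

import org.apache.spark.sql.{DataFrame, SparkSession}

// Hypothetical helper in the shape of withParquetTable/withOrcTable above.
def withRegisteredTable(spark: SparkSession, df: DataFrame, name: String)(f: => Unit): Unit = {
  spark.sqlContext.registerDataFrameAsTable(df, name)
  try f finally spark.catalog.dropTempView(name)
}
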