 mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java | 2
 python/pyspark/sql/context.py | 5
 python/pyspark/sql/session.py | 17
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2
 sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 124
 sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 155
 sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala | 3
 sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala | 4
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 6
 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala | 13
 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala | 2
 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala | 2
 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala | 2
 sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2
 sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 10
 sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala | 19
 sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala | 13
 sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala | 100
 sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala | 2
 sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 8
 sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala | 2
 sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala | 93
 sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala | 4
 sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala | 16
 sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 2
 sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala | 4
 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala | 2
 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala | 14
 sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala | 6
 sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala | 18
 sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala | 2
 sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala | 2
 sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala | 2
 sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala | 4
 sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala | 4
 sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala | 2
 sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala | 2
 sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala | 4
 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala | 2
 sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala | 2
 sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala | 42
 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 2
 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala | 2
 43 files changed, 367 insertions(+), 357 deletions(-)
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
index da623d1d15..7bda219243 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
@@ -56,7 +56,7 @@ public class JavaDefaultReadWriteSuite extends SharedSparkSession {
} catch (IOException e) {
// expected
}
- instance.write().context(spark.wrapped()).overwrite().save(outputPath);
+ instance.write().context(spark.sqlContext()).overwrite().save(outputPath);
MyParams newInstance = MyParams.load(outputPath);
Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
Assert.assertEquals("Params should be preserved.",
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index e8e60c6412..486733a390 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -34,7 +34,10 @@ __all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
class SQLContext(object):
- """Wrapper around :class:`SparkSession`, the main entry point to Spark SQL functionality.
+ """The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x.
+
+ As of Spark 2.0, this is replaced by :class:`SparkSession`. However, we are keeping the class
+ here for backward compatibility.
A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as
tables, execute SQL over tables, cache tables, and read parquet files.
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 257a239c8d..0e04b88265 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -120,6 +120,8 @@ class SparkSession(object):
def appName(self, name):
"""Sets a name for the application, which will be shown in the Spark web UI.
+ If no application name is set, a randomly generated name will be used.
+
:param name: an application name
"""
return self.config("spark.app.name", name)
@@ -133,8 +135,17 @@ class SparkSession(object):
@since(2.0)
def getOrCreate(self):
- """Gets an existing :class:`SparkSession` or, if there is no existing one, creates a new
- one based on the options set in this builder.
+ """Gets an existing :class:`SparkSession` or, if there is no existing one, creates a
+ new one based on the options set in this builder.
+
+ This method first checks whether there is a valid thread-local SparkSession,
+ and if yes, return that one. It then checks whether there is a valid global
+ default SparkSession, and if yes, return that one. If no valid global default
+ SparkSession exists, the method creates a new SparkSession and assigns the
+ newly created SparkSession as the global default.
+
+ In case an existing SparkSession is returned, the config options specified
+ in this builder will be applied to the existing SparkSession.
"""
with self._lock:
from pyspark.conf import SparkConf
@@ -175,7 +186,7 @@ class SparkSession(object):
if jsparkSession is None:
jsparkSession = self._jvm.SparkSession(self._jsc.sc())
self._jsparkSession = jsparkSession
- self._jwrapped = self._jsparkSession.wrapped()
+ self._jwrapped = self._jsparkSession.sqlContext()
self._wrapped = SQLContext(self._sc, self, self._jwrapped)
_monkey_patch_RDD(self)
install_exception_handler()
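
The getOrCreate() semantics documented above (thread-local session first, then the global default, then a new session, with the builder's options applied to whichever session is returned) mirror the Scala builder changed later in this patch. A minimal sketch of the "options are applied to an existing session" case, written in Scala for consistency with the rest of the diff and assuming no session exists yet in the JVM:

    import org.apache.spark.sql.SparkSession

    // First call creates the session (and the underlying SparkContext).
    val first = SparkSession.builder().master("local").appName("demo").getOrCreate()
    // A second getOrCreate() returns the same session and applies the new option to it.
    val second = SparkSession.builder().config("spark.sql.shuffle.partitions", "4").getOrCreate()
    assert(second == first)
    assert(first.conf.get("spark.sql.shuffle.partitions") == "4")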
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 02dd6547a4..78a167eef2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -213,7 +213,7 @@ class Dataset[T] private[sql](
private implicit def classTag = unresolvedTEncoder.clsTag
// sqlContext must be val because a stable identifier is expected when you import implicits
- @transient lazy val sqlContext: SQLContext = sparkSession.wrapped
+ @transient lazy val sqlContext: SQLContext = sparkSession.sqlContext
protected[sql] def resolve(colName: String): NamedExpression = {
queryExecution.analyzed.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver)
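
The comment in this hunk explains why `sqlContext` stays a `val`: importing implicits requires a stable identifier on the left of the dot. A minimal sketch, assuming an existing SparkSession named `spark`:

    val ds = spark.range(3)
    import ds.sqlContext.implicits._   // compiles because sqlContext is a stable val
    val words = Seq("a", "b").toDF("word")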
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a3e2b49556..14d12d30bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -19,25 +19,22 @@ package org.apache.spark.sql
import java.beans.BeanInfo
import java.util.Properties
-import java.util.concurrent.atomic.AtomicReference
import scala.collection.immutable
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.{SparkConf, SparkContext, SparkException}
+import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.ConfigEntry
import org.apache.spark.rdd.RDD
-import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.ShowTablesCommand
-import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
@@ -46,8 +43,8 @@ import org.apache.spark.sql.util.ExecutionListenerManager
/**
* The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x.
*
- * As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class here
- * for backward compatibility.
+ * As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class
+ * here for backward compatibility.
*
* @groupname basic Basic Operations
* @groupname ddl_ops Persistent Catalog DDL
@@ -76,42 +73,21 @@ class SQLContext private[sql](
this(sparkSession, true)
}
+ @deprecated("Use SparkSession.builder instead", "2.0.0")
def this(sc: SparkContext) = {
this(new SparkSession(sc))
}
+ @deprecated("Use SparkSession.builder instead", "2.0.0")
def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)
// TODO: move this logic into SparkSession
- // If spark.sql.allowMultipleContexts is true, we will throw an exception if a user
- // wants to create a new root SQLContext (a SQLContext that is not created by newSession).
- private val allowMultipleContexts =
- sparkContext.conf.getBoolean(
- SQLConf.ALLOW_MULTIPLE_CONTEXTS.key,
- SQLConf.ALLOW_MULTIPLE_CONTEXTS.defaultValue.get)
-
- // Assert no root SQLContext is running when allowMultipleContexts is false.
- {
- if (!allowMultipleContexts && isRootContext) {
- SQLContext.getInstantiatedContextOption() match {
- case Some(rootSQLContext) =>
- val errMsg = "Only one SQLContext/HiveContext may be running in this JVM. " +
- s"It is recommended to use SQLContext.getOrCreate to get the instantiated " +
- s"SQLContext/HiveContext. To ignore this error, " +
- s"set ${SQLConf.ALLOW_MULTIPLE_CONTEXTS.key} = true in SparkConf."
- throw new SparkException(errMsg)
- case None => // OK
- }
- }
- }
-
protected[sql] def sessionState: SessionState = sparkSession.sessionState
protected[sql] def sharedState: SharedState = sparkSession.sharedState
protected[sql] def conf: SQLConf = sessionState.conf
protected[sql] def runtimeConf: RuntimeConfig = sparkSession.conf
protected[sql] def cacheManager: CacheManager = sparkSession.cacheManager
- protected[sql] def listener: SQLListener = sparkSession.listener
protected[sql] def externalCatalog: ExternalCatalog = sparkSession.externalCatalog
def sparkContext: SparkContext = sparkSession.sparkContext
@@ -123,7 +99,7 @@ class SQLContext private[sql](
*
* @since 1.6.0
*/
- def newSession(): SQLContext = sparkSession.newSession().wrapped
+ def newSession(): SQLContext = sparkSession.newSession().sqlContext
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@@ -760,21 +736,6 @@ class SQLContext private[sql](
schema: StructType): DataFrame = {
sparkSession.applySchemaToPythonRDD(rdd, schema)
}
-
- // TODO: move this logic into SparkSession
-
- // Register a successfully instantiated context to the singleton. This should be at the end of
- // the class definition so that the singleton is updated only if there is no exception in the
- // construction of the instance.
- sparkContext.addSparkListener(new SparkListener {
- override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
- SQLContext.clearInstantiatedContext()
- SQLContext.clearSqlListener()
- }
- })
-
- sparkSession.setWrappedContext(self)
- SQLContext.setInstantiatedContext(self)
}
/**
@@ -788,19 +749,6 @@ class SQLContext private[sql](
object SQLContext {
/**
- * The active SQLContext for the current thread.
- */
- private val activeContext: InheritableThreadLocal[SQLContext] =
- new InheritableThreadLocal[SQLContext]
-
- /**
- * Reference to the created SQLContext.
- */
- @transient private val instantiatedContext = new AtomicReference[SQLContext]()
-
- @transient private val sqlListener = new AtomicReference[SQLListener]()
-
- /**
* Get the singleton SQLContext if it exists or create a new one using the given SparkContext.
*
* This function can be used to create a singleton SQLContext object that can be shared across
@@ -811,41 +759,9 @@ object SQLContext {
*
* @since 1.5.0
*/
+ @deprecated("Use SparkSession.builder instead", "2.0.0")
def getOrCreate(sparkContext: SparkContext): SQLContext = {
- val ctx = activeContext.get()
- if (ctx != null && !ctx.sparkContext.isStopped) {
- return ctx
- }
-
- synchronized {
- val ctx = instantiatedContext.get()
- if (ctx == null || ctx.sparkContext.isStopped) {
- new SQLContext(sparkContext)
- } else {
- ctx
- }
- }
- }
-
- private[sql] def clearInstantiatedContext(): Unit = {
- instantiatedContext.set(null)
- }
-
- private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = {
- synchronized {
- val ctx = instantiatedContext.get()
- if (ctx == null || ctx.sparkContext.isStopped) {
- instantiatedContext.set(sqlContext)
- }
- }
- }
-
- private[sql] def getInstantiatedContextOption(): Option[SQLContext] = {
- Option(instantiatedContext.get())
- }
-
- private[sql] def clearSqlListener(): Unit = {
- sqlListener.set(null)
+ SparkSession.builder().sparkContext(sparkContext).getOrCreate().sqlContext
}
/**
@@ -855,8 +771,9 @@ object SQLContext {
*
* @since 1.6.0
*/
+ @deprecated("Use SparkSession.setActiveSession instead", "2.0.0")
def setActive(sqlContext: SQLContext): Unit = {
- activeContext.set(sqlContext)
+ SparkSession.setActiveSession(sqlContext.sparkSession)
}
/**
@@ -865,12 +782,9 @@ object SQLContext {
*
* @since 1.6.0
*/
+ @deprecated("Use SparkSession.clearActiveSession instead", "2.0.0")
def clearActive(): Unit = {
- activeContext.remove()
- }
-
- private[sql] def getActive(): Option[SQLContext] = {
- Option(activeContext.get())
+ SparkSession.clearActiveSession()
}
/**
@@ -895,20 +809,6 @@ object SQLContext {
}
/**
- * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI.
- */
- private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = {
- if (sqlListener.get() == null) {
- val listener = new SQLListener(sc.conf)
- if (sqlListener.compareAndSet(null, listener)) {
- sc.addSparkListener(listener)
- sc.ui.foreach(new SQLTab(listener, _))
- }
- }
- sqlListener.get()
- }
-
- /**
* Extract `spark.sql.*` properties from the conf and return them as a [[Properties]].
*/
private[sql] def getSQLProperties(sparkConf: SparkConf): Properties = {
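
With these changes SQLContext remains source-compatible but is only a thin wrapper: the deprecated constructors and `getOrCreate` delegate to SparkSession. A hedged migration sketch, assuming a fresh local SparkContext:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.{SparkSession, SQLContext}

    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("demo"))
    // Spark 1.x style, now deprecated: still works and is backed by a SparkSession.
    val sqlContext = new SQLContext(sc)
    // Spark 2.0 style: the builder picks up the already-running SparkContext.
    val spark = SparkSession.builder().getOrCreate()
    // SQLContext.getOrCreate(sc) likewise delegates to the builder and returns its sqlContext.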
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 629243bd1a..f697769bdc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql
import java.beans.Introspector
+import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
@@ -30,6 +31,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION
import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
@@ -98,24 +100,10 @@ class SparkSession private(
}
/**
- * A wrapped version of this session in the form of a [[SQLContext]].
+ * A wrapped version of this session in the form of a [[SQLContext]], for backward compatibility.
*/
@transient
- private var _wrapped: SQLContext = _
-
- @transient
- private val _wrappedLock = new Object
-
- protected[sql] def wrapped: SQLContext = _wrappedLock.synchronized {
- if (_wrapped == null) {
- _wrapped = new SQLContext(self, isRootContext = false)
- }
- _wrapped
- }
-
- protected[sql] def setWrappedContext(sqlContext: SQLContext): Unit = _wrappedLock.synchronized {
- _wrapped = sqlContext
- }
+ private[sql] val sqlContext: SQLContext = new SQLContext(this)
protected[sql] def cacheManager: CacheManager = sharedState.cacheManager
protected[sql] def listener: SQLListener = sharedState.listener
@@ -238,7 +226,7 @@ class SparkSession private(
*/
@Experimental
def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
- SQLContext.setActive(wrapped)
+ SparkSession.setActiveSession(this)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType))
@@ -254,7 +242,7 @@ class SparkSession private(
*/
@Experimental
def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
- SQLContext.setActive(wrapped)
+ SparkSession.setActiveSession(this)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data))
@@ -573,7 +561,7 @@ class SparkSession private(
*/
@Experimental
object implicits extends SQLImplicits with Serializable {
- protected override def _sqlContext: SQLContext = wrapped
+ protected override def _sqlContext: SQLContext = SparkSession.this.sqlContext
}
// scalastyle:on
@@ -649,8 +637,16 @@ object SparkSession {
private[this] val options = new scala.collection.mutable.HashMap[String, String]
+ private[this] var userSuppliedContext: Option[SparkContext] = None
+
+ private[sql] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
+ userSuppliedContext = Option(sparkContext)
+ this
+ }
+
/**
* Sets a name for the application, which will be shown in the Spark web UI.
+ * If no application name is set, a randomly generated name will be used.
*
* @since 2.0.0
*/
@@ -735,29 +731,130 @@ object SparkSession {
}
/**
- * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new one
- * based on the options set in this builder.
+ * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new
+ * one based on the options set in this builder.
+ *
+ * This method first checks whether there is a valid thread-local SparkSession,
+ * and if yes, return that one. It then checks whether there is a valid global
+ * default SparkSession, and if yes, return that one. If no valid global default
+ * SparkSession exists, the method creates a new SparkSession and assigns the
+ * newly created SparkSession as the global default.
+ *
+ * In case an existing SparkSession is returned, the config options specified in
+ * this builder will be applied to the existing SparkSession.
*
* @since 2.0.0
*/
def getOrCreate(): SparkSession = synchronized {
- // Step 1. Create a SparkConf
- // Step 2. Get a SparkContext
- // Step 3. Get a SparkSession
- val sparkConf = new SparkConf()
- options.foreach { case (k, v) => sparkConf.set(k, v) }
- val sparkContext = SparkContext.getOrCreate(sparkConf)
-
- SQLContext.getOrCreate(sparkContext).sparkSession
+ // Get the session from current thread's active session.
+ var session = activeThreadSession.get()
+ if ((session ne null) && !session.sparkContext.isStopped) {
+ options.foreach { case (k, v) => session.conf.set(k, v) }
+ return session
+ }
+
+ // Global synchronization so we will only set the default session once.
+ SparkSession.synchronized {
+ // If the current thread does not have an active session, get it from the global session.
+ session = defaultSession.get()
+ if ((session ne null) && !session.sparkContext.isStopped) {
+ options.foreach { case (k, v) => session.conf.set(k, v) }
+ return session
+ }
+
+ // No active nor global default session. Create a new one.
+ val sparkContext = userSuppliedContext.getOrElse {
+ // set app name if not given
+ if (!options.contains("spark.app.name")) {
+ options += "spark.app.name" -> java.util.UUID.randomUUID().toString
+ }
+
+ val sparkConf = new SparkConf()
+ options.foreach { case (k, v) => sparkConf.set(k, v) }
+ SparkContext.getOrCreate(sparkConf)
+ }
+ session = new SparkSession(sparkContext)
+ options.foreach { case (k, v) => session.conf.set(k, v) }
+ defaultSession.set(session)
+
+ // Register a successfully instantiated context to the singleton. This should be at the
+ // end of the class definition so that the singleton is updated only if there is no
+ // exception in the construction of the instance.
+ sparkContext.addSparkListener(new SparkListener {
+ override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
+ defaultSession.set(null)
+ sqlListener.set(null)
+ }
+ })
+ }
+
+ return session
}
}
/**
* Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
+ *
* @since 2.0.0
*/
def builder(): Builder = new Builder
+ /**
+ * Changes the SparkSession that will be returned in this thread and its children when
+ * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
+ * a SparkSession with an isolated session, instead of the global (first created) context.
+ *
+ * @since 2.0.0
+ */
+ def setActiveSession(session: SparkSession): Unit = {
+ activeThreadSession.set(session)
+ }
+
+ /**
+ * Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will
+ * return the first created context instead of a thread-local override.
+ *
+ * @since 2.0.0
+ */
+ def clearActiveSession(): Unit = {
+ activeThreadSession.remove()
+ }
+
+ /**
+ * Sets the default SparkSession that is returned by the builder.
+ *
+ * @since 2.0.0
+ */
+ def setDefaultSession(session: SparkSession): Unit = {
+ defaultSession.set(session)
+ }
+
+ /**
+ * Clears the default SparkSession that is returned by the builder.
+ *
+ * @since 2.0.0
+ */
+ def clearDefaultSession(): Unit = {
+ defaultSession.set(null)
+ }
+
+ private[sql] def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get)
+
+ private[sql] def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)
+
+ /** A global SQL listener used for the SQL UI. */
+ private[sql] val sqlListener = new AtomicReference[SQLListener]()
+
+ ////////////////////////////////////////////////////////////////////////////////////////
+ // Private methods from now on
+ ////////////////////////////////////////////////////////////////////////////////////////
+
+ /** The active SparkSession for the current thread. */
+ private val activeThreadSession = new InheritableThreadLocal[SparkSession]
+
+ /** Reference to the root SparkSession. */
+ private val defaultSession = new AtomicReference[SparkSession]
+
private val HIVE_SHARED_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSharedState"
private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState"
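
The thread-local override described in `setActiveSession` above works as follows; this is the same scenario exercised by the new SparkSessionBuilderSuite later in this patch:

    // Sketch of the thread-local active session versus the global default.
    val defaultSession = SparkSession.builder().master("local").getOrCreate()
    val isolated = defaultSession.newSession()
    SparkSession.setActiveSession(isolated)
    // While the override is set, getOrCreate() on this thread returns the isolated session.
    assert(SparkSession.builder().getOrCreate() == isolated)
    SparkSession.clearActiveSession()
    // Without an override, the global default session is returned again.
    assert(SparkSession.builder().getOrCreate() == defaultSession)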
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 85af4faf4d..d8911f88b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -157,7 +157,8 @@ private[sql] case class RowDataSourceScanExec(
val outputUnsafeRows = relation match {
case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
- !SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED)
+ !SparkSession.getActiveSession.get.sessionState.conf.getConf(
+ SQLConf.PARQUET_VECTORIZED_READER_ENABLED)
case _: HadoopFsRelation => true
case _ => false
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index cb3c46a98b..34187b9a1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -60,7 +60,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
}
lazy val analyzed: LogicalPlan = {
- SQLContext.setActive(sparkSession.wrapped)
+ SparkSession.setActiveSession(sparkSession)
sparkSession.sessionState.analyzer.execute(logical)
}
@@ -73,7 +73,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData)
lazy val sparkPlan: SparkPlan = {
- SQLContext.setActive(sparkSession.wrapped)
+ SparkSession.setActiveSession(sparkSession)
planner.plan(ReturnAnswer(optimizedPlan)).next()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index b94b84d77a..045ccc7bd6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -27,7 +27,7 @@ import org.apache.spark.{broadcast, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd.{RDD, RDDOperationScope}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -50,7 +50,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* populated by the query planning infrastructure.
*/
@transient
- protected[spark] final val sqlContext = SQLContext.getActive().orNull
+ final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull
protected def sparkContext = sqlContext.sparkContext
@@ -65,7 +65,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/** Overridden make copy also propagates sqlContext to copied plan. */
override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
- SQLContext.setActive(sqlContext)
+ SparkSession.setActiveSession(sqlContext.sparkSession)
super.makeCopy(newArgs)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index ccad9b3fd5..2e17b763a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -178,7 +178,7 @@ case class DataSource(
providingClass.newInstance() match {
case s: StreamSourceProvider =>
val (name, schema) = s.sourceSchema(
- sparkSession.wrapped, userSpecifiedSchema, className, options)
+ sparkSession.sqlContext, userSpecifiedSchema, className, options)
SourceInfo(name, schema)
case format: FileFormat =>
@@ -198,7 +198,8 @@ case class DataSource(
def createSource(metadataPath: String): Source = {
providingClass.newInstance() match {
case s: StreamSourceProvider =>
- s.createSource(sparkSession.wrapped, metadataPath, userSpecifiedSchema, className, options)
+ s.createSource(
+ sparkSession.sqlContext, metadataPath, userSpecifiedSchema, className, options)
case format: FileFormat =>
val path = new CaseInsensitiveMap(options).getOrElse("path", {
@@ -215,7 +216,7 @@ case class DataSource(
/** Returns a sink that can be used to continually write data. */
def createSink(): Sink = {
providingClass.newInstance() match {
- case s: StreamSinkProvider => s.createSink(sparkSession.wrapped, options, partitionColumns)
+ case s: StreamSinkProvider => s.createSink(sparkSession.sqlContext, options, partitionColumns)
case parquet: parquet.DefaultSource =>
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
@@ -265,9 +266,9 @@ case class DataSource(
val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
// TODO: Throw when too much is given.
case (dataSource: SchemaRelationProvider, Some(schema)) =>
- dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions, schema)
+ dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema)
case (dataSource: RelationProvider, None) =>
- dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions)
+ dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
case (_: SchemaRelationProvider, None) =>
throw new AnalysisException(s"A schema needs to be specified when using $className.")
case (_: RelationProvider, Some(_)) =>
@@ -383,7 +384,7 @@ case class DataSource(
providingClass.newInstance() match {
case dataSource: CreatableRelationProvider =>
- dataSource.createRelation(sparkSession.wrapped, mode, options, data)
+ dataSource.createRelation(sparkSession.sqlContext, mode, options, data)
case format: FileFormat =>
// Don't glob path for the write path. The contracts here are:
// 1. Only one output path can be specified on the write path;
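
DataSource now hands `sparkSession.sqlContext` to the provider interfaces, so third-party sources written against the SQLContext-based APIs keep working unchanged. A minimal hypothetical provider sketch (the class name and schema are illustrative only):

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.sources.{BaseRelation, RelationProvider}
    import org.apache.spark.sql.types.{StringType, StructField, StructType}

    class EchoSourceProvider extends RelationProvider {
      override def createRelation(
          ctx: SQLContext, parameters: Map[String, String]): BaseRelation =
        new BaseRelation {
          // The SQLContext received here is the wrapper exposed by the SparkSession.
          override val sqlContext: SQLContext = ctx
          override val schema: StructType = StructType(StructField("value", StringType) :: Nil)
        }
    }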
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
index 8d332df029..88125a2b4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
@@ -142,7 +142,7 @@ case class HadoopFsRelation(
fileFormat: FileFormat,
options: Map[String, String]) extends BaseRelation with FileRelation {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
val schema: StructType = {
val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index bcf70fdc4a..233b7891d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -92,7 +92,7 @@ private[sql] case class JDBCRelation(
with PrunedFilteredScan
with InsertableRelation {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
override val needConversion: Boolean = false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index df6304d85f..7d09bdcebd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -173,7 +173,7 @@ class StreamExecution(
startLatch.countDown()
// While active, repeatedly attempt to run batches.
- SQLContext.setActive(sparkSession.wrapped)
+ SparkSession.setActiveSession(sparkSession)
triggerExecutor.execute(() => {
if (isActive) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 65bc043076..0b490fe71c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1168,7 +1168,7 @@ object functions {
* @group normal_funcs
*/
def expr(expr: String): Column = {
- val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse {
+ val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse {
new SparkSqlParser(new SQLConf)
}
Column(parser.parseExpression(expr))
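
`expr()` now asks the active SparkSession for its SQL parser, falling back to a fresh `SparkSqlParser` when no session is active. A small usage sketch, assuming an active SparkSession named `spark`:

    import org.apache.spark.sql.functions.expr

    val incremented = spark.range(3).select(expr("id + 1"))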
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5d18689801..35d67ca2d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -70,16 +70,6 @@ object SQLConf {
.intConf
.createWithDefault(10)
- val ALLOW_MULTIPLE_CONTEXTS = SQLConfigBuilder("spark.sql.allowMultipleContexts")
- .doc("When set to true, creating multiple SQLContexts/HiveContexts is allowed. " +
- "When set to false, only one SQLContext/HiveContext is allowed to be created " +
- "through the constructor (new SQLContexts/HiveContexts created through newSession " +
- "method is allowed). Please note that this conf needs to be set in Spark Conf. Once " +
- "a SQLContext/HiveContext has been created, changing the value of this conf will not " +
- "have effect.")
- .booleanConf
- .createWithDefault(true)
-
val COMPRESS_CACHED = SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.compressed")
.internal()
.doc("When set to true Spark SQL will automatically select a compression codec for each " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index eaf993aaed..9f6137d6e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -18,10 +18,10 @@
package org.apache.spark.sql.internal
import org.apache.spark.SparkContext
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog}
import org.apache.spark.sql.execution.CacheManager
-import org.apache.spark.sql.execution.ui.SQLListener
+import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.util.MutableURLClassLoader
@@ -38,7 +38,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) {
/**
* A listener for SQL-specific [[org.apache.spark.scheduler.SparkListenerEvent]]s.
*/
- val listener: SQLListener = SQLContext.createListenerAndUI(sparkContext)
+ val listener: SQLListener = createListenerAndUI(sparkContext)
/**
* A catalog that interacts with external systems.
@@ -51,6 +51,19 @@ private[sql] class SharedState(val sparkContext: SparkContext) {
val jarClassLoader = new NonClosableMutableURLClassLoader(
org.apache.spark.util.Utils.getContextOrSparkClassLoader)
+ /**
+ * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI.
+ */
+ private def createListenerAndUI(sc: SparkContext): SQLListener = {
+ if (SparkSession.sqlListener.get() == null) {
+ val listener = new SQLListener(sc.conf)
+ if (SparkSession.sqlListener.compareAndSet(null, listener)) {
+ sc.addSparkListener(listener)
+ sc.ui.foreach(new SQLTab(listener, _))
+ }
+ }
+ SparkSession.sqlListener.get()
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index 65fe271b69..b447006761 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -39,7 +39,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
test("get all tables") {
checkAnswer(
- spark.wrapped.tables().filter("tableName = 'listtablessuitetable'"),
+ spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
checkAnswer(
@@ -48,12 +48,12 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
spark.sessionState.catalog.dropTable(
TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
- assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
+ assert(spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
}
test("getting all tables with a database name has no impact on returned table names") {
checkAnswer(
- spark.wrapped.tables("default").filter("tableName = 'listtablessuitetable'"),
+ spark.sqlContext.tables("default").filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
checkAnswer(
@@ -62,7 +62,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
spark.sessionState.catalog.dropTable(
TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
- assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
+ assert(spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
}
test("query the returned DataFrame of tables") {
@@ -70,7 +70,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
StructField("tableName", StringType, false) ::
StructField("isTemporary", BooleanType, false) :: Nil)
- Seq(spark.wrapped.tables(), sql("SHOW TABLes")).foreach {
+ Seq(spark.sqlContext.tables(), sql("SHOW TABLes")).foreach {
case tableDF =>
assert(expectedSchema === tableDF.schema)
@@ -81,7 +81,8 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
Row(true, "listtablessuitetable")
)
checkAnswer(
- spark.wrapped.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
+ spark.sqlContext.tables()
+ .filter("tableName = 'tables'").select("tableName", "isTemporary"),
Row("tables", true))
spark.catalog.dropTempView("tables")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala
deleted file mode 100644
index 0b5a92c256..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql
-
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark._
-import org.apache.spark.sql.internal.SQLConf
-
-class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll {
-
- private var originalActiveSQLContext: Option[SQLContext] = _
- private var originalInstantiatedSQLContext: Option[SQLContext] = _
- private var sparkConf: SparkConf = _
-
- override protected def beforeAll(): Unit = {
- originalActiveSQLContext = SQLContext.getActive()
- originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption()
-
- SQLContext.clearActive()
- SQLContext.clearInstantiatedContext()
- sparkConf =
- new SparkConf(false)
- .setMaster("local[*]")
- .setAppName("test")
- .set("spark.ui.enabled", "false")
- .set("spark.driver.allowMultipleContexts", "true")
- }
-
- override protected def afterAll(): Unit = {
- // Set these states back.
- originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx))
- originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx))
- }
-
- def testNewSession(rootSQLContext: SQLContext): Unit = {
- // Make sure we can successfully create new Session.
- rootSQLContext.newSession()
-
- // Reset the state. It is always safe to clear the active context.
- SQLContext.clearActive()
- }
-
- def testCreatingNewSQLContext(allowsMultipleContexts: Boolean): Unit = {
- val conf =
- sparkConf
- .clone
- .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowsMultipleContexts.toString)
- val sparkContext = new SparkContext(conf)
-
- try {
- if (allowsMultipleContexts) {
- new SQLContext(sparkContext)
- SQLContext.clearActive()
- } else {
- // If allowsMultipleContexts is false, make sure we can get the error.
- val message = intercept[SparkException] {
- new SQLContext(sparkContext)
- }.getMessage
- assert(message.contains("Only one SQLContext/HiveContext may be running"))
- }
- } finally {
- sparkContext.stop()
- }
- }
-
- test("test the flag to disallow creating multiple root SQLContext") {
- Seq(false, true).foreach { allowMultipleSQLContexts =>
- val conf =
- sparkConf
- .clone
- .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowMultipleSQLContexts.toString)
- val sc = new SparkContext(conf)
- try {
- val rootSQLContext = new SQLContext(sc)
- testNewSession(rootSQLContext)
- testNewSession(rootSQLContext)
- testCreatingNewSQLContext(allowMultipleSQLContexts)
- } finally {
- sc.stop()
- SQLContext.clearInstantiatedContext()
- }
- }
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index 38d7b6e25b..c9594a7e9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -40,7 +40,7 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
val newSession = sqlContext.newSession()
assert(SQLContext.getOrCreate(sc).eq(sqlContext),
"SQLContext.getOrCreate after explicitly created SQLContext did not return the context")
- SQLContext.setActive(newSession)
+ SparkSession.setActiveSession(newSession.sparkSession)
assert(SQLContext.getOrCreate(sc).eq(newSession),
"SQLContext.getOrCreate after explicitly setActive() did not return the active context")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 743a27aa7a..460e34a5ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1042,7 +1042,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SET commands semantics using sql()") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
val testKey = "test.key.0"
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
@@ -1083,17 +1083,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
sql(s"SET $nonexistentKey"),
Row(nonexistentKey, "<undefined>")
)
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
}
test("SET commands with illegal or inappropriate argument") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
// Set negative mapred.reduce.tasks for automatically determining
// the number of reducers is not supported
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1"))
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01"))
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2"))
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
}
test("apply schema") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
index b489b74fec..cd6b2647e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
@@ -25,6 +25,6 @@ class SerializationSuite extends SparkFunSuite with SharedSQLContext {
test("[SPARK-5235] SQLContext should be serializable") {
val spark = SparkSession.builder.getOrCreate()
- new JavaSerializer(new SparkConf()).newInstance().serialize(spark.wrapped)
+ new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sqlContext)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
new file mode 100644
index 0000000000..ec6a2b3575
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.{SparkContext, SparkFunSuite}
+
+/**
+ * Test cases for the builder pattern of [[SparkSession]].
+ */
+class SparkSessionBuilderSuite extends SparkFunSuite {
+
+ private var initialSession: SparkSession = _
+
+ private lazy val sparkContext: SparkContext = {
+ initialSession = SparkSession.builder()
+ .master("local")
+ .config("spark.ui.enabled", value = false)
+ .config("some-config", "v2")
+ .getOrCreate()
+ initialSession.sparkContext
+ }
+
+ test("create with config options and propagate them to SparkContext and SparkSession") {
+ // Creating a new session with config - this works by just calling the lazy val
+ sparkContext
+ assert(initialSession.sparkContext.conf.get("some-config") == "v2")
+ assert(initialSession.conf.get("some-config") == "v2")
+ SparkSession.clearDefaultSession()
+ }
+
+ test("use global default session") {
+ val session = SparkSession.builder().getOrCreate()
+ assert(SparkSession.builder().getOrCreate() == session)
+ SparkSession.clearDefaultSession()
+ }
+
+ test("config options are propagated to existing SparkSession") {
+ val session1 = SparkSession.builder().config("spark-config1", "a").getOrCreate()
+ assert(session1.conf.get("spark-config1") == "a")
+ val session2 = SparkSession.builder().config("spark-config1", "b").getOrCreate()
+ assert(session1 == session2)
+ assert(session1.conf.get("spark-config1") == "b")
+ SparkSession.clearDefaultSession()
+ }
+
+ test("use session from active thread session and propagate config options") {
+ val defaultSession = SparkSession.builder().getOrCreate()
+ val activeSession = defaultSession.newSession()
+ SparkSession.setActiveSession(activeSession)
+ val session = SparkSession.builder().config("spark-config2", "a").getOrCreate()
+
+ assert(activeSession != defaultSession)
+ assert(session == activeSession)
+ assert(session.conf.get("spark-config2") == "a")
+ SparkSession.clearActiveSession()
+
+ assert(SparkSession.builder().getOrCreate() == defaultSession)
+ SparkSession.clearDefaultSession()
+ }
+
+ test("create a new session if the default session has been stopped") {
+ val defaultSession = SparkSession.builder().getOrCreate()
+ SparkSession.setDefaultSession(defaultSession)
+ defaultSession.stop()
+ val newSession = SparkSession.builder().master("local").getOrCreate()
+ assert(newSession != defaultSession)
+ newSession.stop()
+ }
+
+ test("create a new session if the active thread session has been stopped") {
+ val activeSession = SparkSession.builder().master("local").getOrCreate()
+ SparkSession.setActiveSession(activeSession)
+ activeSession.stop()
+ val newSession = SparkSession.builder().master("local").getOrCreate()
+ assert(newSession != activeSession)
+ newSession.stop()
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
index 9523f6f9f5..4de3cf605c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
@@ -26,9 +26,9 @@ class StatisticsSuite extends QueryTest with SharedSQLContext {
val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
assert(df.queryExecution.analyzed.statistics.sizeInBytes >
- spark.wrapped.conf.autoBroadcastJoinThreshold)
+ spark.sessionState.conf.autoBroadcastJoinThreshold)
assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes >
- spark.wrapped.conf.autoBroadcastJoinThreshold)
+ spark.sessionState.conf.autoBroadcastJoinThreshold)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 70a00a43f7..2f45db3925 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -27,21 +27,21 @@ import org.apache.spark.sql.internal.SQLConf
class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
- private var originalActiveSQLContext: Option[SQLContext] = _
- private var originalInstantiatedSQLContext: Option[SQLContext] = _
+ private var originalActiveSQLContext: Option[SparkSession] = _
+ private var originalInstantiatedSQLContext: Option[SparkSession] = _
override protected def beforeAll(): Unit = {
- originalActiveSQLContext = SQLContext.getActive()
- originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption()
+ originalActiveSQLContext = SparkSession.getActiveSession
+ originalInstantiatedSQLContext = SparkSession.getDefaultSession
- SQLContext.clearActive()
- SQLContext.clearInstantiatedContext()
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
}
override protected def afterAll(): Unit = {
// Set these states back.
- originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx))
- originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx))
+ originalActiveSQLContext.foreach(ctx => SparkSession.setActiveSession(ctx))
+ originalInstantiatedSQLContext.foreach(ctx => SparkSession.setDefaultSession(ctx))
}
private def checkEstimation(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 2a5295d0d2..8243470b19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -155,7 +155,7 @@ class PlannerSuite extends SharedSQLContext {
val path = file.getCanonicalPath
testData.write.parquet(path)
val df = spark.read.parquet(path)
- spark.wrapped.registerDataFrameAsTable(df, "testPushed")
+ spark.sqlContext.registerDataFrameAsTable(df, "testPushed")
withTempTable("testPushed") {
val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index d7eae21f9f..9fe0e9646e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -91,7 +91,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
SparkPlanTest
- .checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.wrapped) match {
+ .checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -115,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
SparkPlanTest.checkAnswer(
- input, planFunction, expectedPlanFunction, sortAnswers, spark.wrapped) match {
+ input, planFunction, expectedPlanFunction, sortAnswers, spark.sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index b5fc51603e..1753b84ba6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -90,7 +90,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
(data: Seq[T], tableName: String, testVectorized: Boolean = true)
(f: => Unit): Unit = {
withParquetDataFrame(data, testVectorized) { df =>
- spark.wrapped.registerDataFrameAsTable(df, tableName)
+ spark.sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 4fa1754253..bd197be655 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -60,13 +60,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
val opId = 0
val rdd1 =
makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
- spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(
+ spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(
increment)
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
// Generate next version of stores
val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
- spark.wrapped, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
+ spark.sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
// Make sure the previous RDD still has the same data.
@@ -82,7 +82,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
spark: SparkSession,
seq: Seq[String],
storeVersion: Int): RDD[(String, Int)] = {
- implicit val sqlContext = spark.wrapped
+ implicit val sqlContext = spark.sqlContext
makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore(
sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment)
}
@@ -102,7 +102,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
test("usage with iterators - only gets and only puts") {
withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
- implicit val sqlContext = spark.wrapped
+ implicit val sqlContext = spark.sqlContext
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val opId = 0
@@ -131,7 +131,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
}
val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
- spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
+ spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
@@ -150,7 +150,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
- implicit val sqlContext = spark.wrapped
+ implicit val sqlContext = spark.sqlContext
val coordinatorRef = sqlContext.streams.stateStoreCoordinator
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1")
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2")
@@ -183,7 +183,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
SparkSession.builder
.config(sparkConf.setMaster("local-cluster[2, 1, 1024]"))
.getOrCreate()) { spark =>
- implicit val sqlContext = spark.wrapped
+ implicit val sqlContext = spark.sqlContext
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val opId = 0
val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
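(Sketch, not part of the patch) A minimal illustration of how suites like this one now supply the implicit SQLContext that helpers such as mapPartitionsWithStateStore expect: it is derived from the active SparkSession rather than taken from the removed `wrapped` field. The builder settings are hypothetical.

    import org.apache.spark.sql.{SQLContext, SparkSession}

    val spark = SparkSession.builder.master("local").appName("state-store-sketch").getOrCreate()
    // Any method with an `implicit sqlContext: SQLContext` parameter now resolves to
    // the SQLContext wrapper owned by the session.
    implicit val sqlContext: SQLContext = spark.sqlContext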
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 1c467137ba..2374ffaaa5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -24,7 +24,7 @@ import org.mockito.Mockito.mock
import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -400,8 +400,8 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite {
.set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly
val sc = new SparkContext(conf)
try {
- SQLContext.clearSqlListener()
- val spark = new SQLContext(sc)
+ SparkSession.sqlListener.set(null)
+ val spark = new SparkSession(sc)
import spark.implicits._
// Run 100 successful executions and 100 failed executions.
// Each execution only has one job and one stage.
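(Sketch, not part of the patch) An illustration of the updated leak-test setup: the JVM-wide SQL listener is reset through SparkSession.sqlListener and the session is built directly on an existing SparkContext. Both members appear to be visible only to org.apache.spark.sql test code, so this sketch assumes it lives in that package; the conf values are hypothetical.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SparkSession

    val conf = new SparkConf().setMaster("local").setAppName("listener-sketch")
    val sc = new SparkContext(conf)
    try {
      SparkSession.sqlListener.set(null)   // clear the shared listener between tests
      val spark = new SparkSession(sc)     // wrap the existing SparkContext
      import spark.implicits._
      Seq(1, 2, 3).toDF("i").count()
    } finally {
      sc.stop()
    }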
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index 81bc973be7..0296229100 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -35,7 +35,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
// Set a conf first.
spark.conf.set(testKey, testVal)
// Clear the conf.
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
// After clear, only overrideConfs used by unit test should be in the SQLConf.
assert(spark.conf.getAll === TestSQLContext.overrideConfs)
@@ -50,11 +50,11 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
assert(spark.conf.get(testKey, testVal + "_") === testVal)
assert(spark.conf.getAll.contains(testKey))
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
}
test("parse SQL set commands") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
sql(s"set $testKey=$testVal")
assert(spark.conf.get(testKey, testVal + "_") === testVal)
assert(spark.conf.get(testKey, testVal + "_") === testVal)
@@ -72,11 +72,11 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
sql(s"set $key=")
assert(spark.conf.get(key, "0") === "")
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
}
test("set command for display") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
checkAnswer(
sql("SET").where("key = 'spark.sql.groupByOrdinal'").select("key", "value"),
Nil)
@@ -97,7 +97,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
}
test("deprecated property") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
val original = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS)
try{
sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
@@ -108,7 +108,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
}
test("invalid conf value") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
val e = intercept[IllegalArgumentException] {
sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
}
@@ -116,7 +116,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
}
test("Test SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE's method") {
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100")
assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 100)
@@ -144,7 +144,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g")
}
- spark.wrapped.conf.clear()
+ spark.sqlContext.conf.clear()
}
test("SparkSession can access configs set in SparkConf") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
index 612cfc7ec7..a34f70ed65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
@@ -41,7 +41,7 @@ case class SimpleDDLScan(
table: String)(@transient val sparkSession: SparkSession)
extends BaseRelation with TableScan {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType =
StructType(Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 51d04f2f4e..f969660ddd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -40,7 +40,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S
extends BaseRelation
with PrunedFilteredScan {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType =
StructType(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index cd0256db43..9cdf7dea76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -37,7 +37,7 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sparkSession: Spa
extends BaseRelation
with PrunedScan {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType =
StructType(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 34b8726a92..cddf4a1884 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -38,7 +38,7 @@ class SimpleScanSource extends RelationProvider {
case class SimpleScan(from: Int, to: Int)(@transient val sparkSession: SparkSession)
extends BaseRelation with TableScan {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType =
StructType(StructField("i", IntegerType, nullable = false) :: Nil)
@@ -70,7 +70,7 @@ case class AllDataTypesScan(
extends BaseRelation
with TableScan {
- override def sqlContext: SQLContext = sparkSession.wrapped
+ override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType = userSpecifiedSchema
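(Sketch, not part of the patch) An illustration of the relation pattern these data-source suites share after the change: a BaseRelation is constructed with a SparkSession and satisfies the sqlContext member by delegating to sparkSession.sqlContext. The class name is hypothetical; the shape mirrors SimpleScan above.

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{Row, SQLContext, SparkSession}
    import org.apache.spark.sql.sources.{BaseRelation, TableScan}
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    case class SketchScan(from: Int, to: Int)(@transient val sparkSession: SparkSession)
      extends BaseRelation with TableScan {

      // The BaseRelation contract still asks for a SQLContext; delegate to the session.
      override def sqlContext: SQLContext = sparkSession.sqlContext

      override def schema: StructType =
        StructType(StructField("i", IntegerType, nullable = false) :: Nil)

      override def buildScan(): RDD[Row] =
        sparkSession.sparkContext.parallelize(from to to).map(Row(_))
    }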
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index ff53505549..e6c0ce95e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -355,14 +355,14 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
q.stop()
verify(LastOptions.mockStreamSourceProvider).createSource(
- spark.wrapped,
+ spark.sqlContext,
checkpointLocation + "/sources/0",
None,
"org.apache.spark.sql.streaming.test",
Map.empty)
verify(LastOptions.mockStreamSourceProvider).createSource(
- spark.wrapped,
+ spark.sqlContext,
checkpointLocation + "/sources/1",
None,
"org.apache.spark.sql.streaming.test",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 421f6bca7f..0cfe260e52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -30,7 +30,7 @@ private[sql] trait SQLTestData { self =>
// Helper object to import SQL implicits without a concrete SQLContext
private object internalImplicits extends SQLImplicits {
- protected override def _sqlContext: SQLContext = self.spark.wrapped
+ protected override def _sqlContext: SQLContext = self.spark.sqlContext
}
import internalImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 51538eca64..853dd0ff3f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -66,7 +66,7 @@ private[sql] trait SQLTestUtils
* but the implicits import is needed in the constructor.
*/
protected object testImplicits extends SQLImplicits {
- protected override def _sqlContext: SQLContext = self.spark.wrapped
+ protected override def _sqlContext: SQLContext = self.spark.sqlContext
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
index 620bfa995a..79c37faa4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
@@ -44,13 +44,13 @@ trait SharedSQLContext extends SQLTestUtils {
/**
* The [[TestSQLContext]] to use for all tests in this suite.
*/
- protected implicit def sqlContext: SQLContext = _spark.wrapped
+ protected implicit def sqlContext: SQLContext = _spark.sqlContext
/**
* Initialize the [[TestSparkSession]].
*/
protected override def beforeAll(): Unit = {
- SQLContext.clearSqlListener()
+ SparkSession.sqlListener.set(null)
if (_spark == null) {
_spark = new TestSparkSession(sparkConf)
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
index 8de223f444..638911599a 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
@@ -56,7 +56,7 @@ private[hive] object SparkSQLEnv extends Logging {
val sparkSession = SparkSession.builder.config(sparkConf).enableHiveSupport().getOrCreate()
sparkContext = sparkSession.sparkContext
- sqlContext = sparkSession.wrapped
+ sqlContext = sparkSession.sqlContext
val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState]
sessionState.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8"))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
index d2cb62c617..7c74a0308d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
@@ -30,7 +30,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
override protected def beforeEach(): Unit = {
super.beforeEach()
- if (spark.wrapped.tableNames().contains("src")) {
+ if (spark.sqlContext.tableNames().contains("src")) {
spark.catalog.dropTempView("src")
}
Seq((1, "")).toDF("key", "value").createOrReplaceTempView("src")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
index 6c9ce208db..622b043581 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
@@ -36,11 +36,11 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
- assert(spark.wrapped.tableNames().contains("t"))
+ assert(spark.sqlContext.tableNames().contains("t"))
checkAnswer(spark.table("t"), df)
}
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df)
checkTablePath(db, "t")
@@ -50,7 +50,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
test(s"saveAsTable() to non-default database - without USE - Overwrite") {
withTempDatabase { db =>
df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df)
checkTablePath(db, "t")
@@ -65,7 +65,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
spark.catalog.createExternalTable("t", path, "parquet")
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table("t"), df)
sql(
@@ -76,7 +76,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
| path '$path'
|)
""".stripMargin)
- assert(spark.wrapped.tableNames(db).contains("t1"))
+ assert(spark.sqlContext.tableNames(db).contains("t1"))
checkAnswer(spark.table("t1"), df)
}
}
@@ -90,7 +90,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
spark.catalog.createExternalTable(s"$db.t", path, "parquet")
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df)
sql(
@@ -101,7 +101,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
| path '$path'
|)
""".stripMargin)
- assert(spark.wrapped.tableNames(db).contains("t1"))
+ assert(spark.sqlContext.tableNames(db).contains("t1"))
checkAnswer(spark.table(s"$db.t1"), df)
}
}
@@ -112,11 +112,11 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
df.write.mode(SaveMode.Append).saveAsTable("t")
- assert(spark.wrapped.tableNames().contains("t"))
+ assert(spark.sqlContext.tableNames().contains("t"))
checkAnswer(spark.table("t"), df.union(df))
}
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df.union(df))
checkTablePath(db, "t")
@@ -127,7 +127,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
df.write.mode(SaveMode.Append).saveAsTable(s"$db.t")
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df.union(df))
checkTablePath(db, "t")
@@ -138,7 +138,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
- assert(spark.wrapped.tableNames().contains("t"))
+ assert(spark.sqlContext.tableNames().contains("t"))
df.write.insertInto(s"$db.t")
checkAnswer(spark.table(s"$db.t"), df.union(df))
@@ -150,10 +150,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
- assert(spark.wrapped.tableNames().contains("t"))
+ assert(spark.sqlContext.tableNames().contains("t"))
}
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
df.write.insertInto(s"$db.t")
checkAnswer(spark.table(s"$db.t"), df.union(df))
@@ -175,21 +175,21 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
sql(s"CREATE TABLE t (key INT)")
- assert(spark.wrapped.tableNames().contains("t"))
- assert(!spark.wrapped.tableNames("default").contains("t"))
+ assert(spark.sqlContext.tableNames().contains("t"))
+ assert(!spark.sqlContext.tableNames("default").contains("t"))
}
- assert(!spark.wrapped.tableNames().contains("t"))
- assert(spark.wrapped.tableNames(db).contains("t"))
+ assert(!spark.sqlContext.tableNames().contains("t"))
+ assert(spark.sqlContext.tableNames(db).contains("t"))
activateDatabase(db) {
sql(s"DROP TABLE t")
- assert(!spark.wrapped.tableNames().contains("t"))
- assert(!spark.wrapped.tableNames("default").contains("t"))
+ assert(!spark.sqlContext.tableNames().contains("t"))
+ assert(!spark.sqlContext.tableNames("default").contains("t"))
}
- assert(!spark.wrapped.tableNames().contains("t"))
- assert(!spark.wrapped.tableNames(db).contains("t"))
+ assert(!spark.sqlContext.tableNames().contains("t"))
+ assert(!spark.sqlContext.tableNames(db).contains("t"))
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 81f3ea8a6e..8a31a49d97 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1417,7 +1417,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
""".stripMargin)
checkAnswer(
- spark.wrapped.tables().select('isTemporary).filter('tableName === "t2"),
+ spark.sqlContext.tables().select('isTemporary).filter('tableName === "t2"),
Row(true)
)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
index aba60da33f..bb351e20c5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
@@ -61,7 +61,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton {
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withOrcDataFrame(data) { df =>
- spark.wrapped.registerDataFrameAsTable(df, tableName)
+ spark.sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}