aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala17
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala5
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala2
3 files changed, 14 insertions, 10 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index cd1fdc4edb..39471d2fb7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1229,7 +1229,7 @@ class SQLContext private[sql](
// construction of the instance.
sparkContext.addSparkListener(new SparkListener {
override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
- SQLContext.clearInstantiatedContext(self)
+ SQLContext.clearInstantiatedContext()
}
})
@@ -1270,13 +1270,13 @@ object SQLContext {
*/
def getOrCreate(sparkContext: SparkContext): SQLContext = {
val ctx = activeContext.get()
- if (ctx != null) {
+ if (ctx != null && !ctx.sparkContext.isStopped) {
return ctx
}
synchronized {
val ctx = instantiatedContext.get()
- if (ctx == null) {
+ if (ctx == null || ctx.sparkContext.isStopped) {
new SQLContext(sparkContext)
} else {
ctx
@@ -1284,12 +1284,17 @@ object SQLContext {
}
}
- private[sql] def clearInstantiatedContext(sqlContext: SQLContext): Unit = {
- instantiatedContext.compareAndSet(sqlContext, null)
+ private[sql] def clearInstantiatedContext(): Unit = {
+ instantiatedContext.set(null)
}
private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = {
- instantiatedContext.compareAndSet(null, sqlContext)
+ synchronized {
+ val ctx = instantiatedContext.get()
+ if (ctx == null || ctx.sparkContext.isStopped) {
+ instantiatedContext.set(sqlContext)
+ }
+ }
}
private[sql] def getInstantiatedContextOption(): Option[SQLContext] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala
index 0e8fcb6a85..34c5c68fd1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala
@@ -31,7 +31,7 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll {
originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption()
SQLContext.clearActive()
- originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx))
+ SQLContext.clearInstantiatedContext()
sparkConf =
new SparkConf(false)
.setMaster("local[*]")
@@ -89,10 +89,9 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll {
testNewSession(rootSQLContext)
testNewSession(rootSQLContext)
testCreatingNewSQLContext(allowMultipleSQLContexts)
-
- SQLContext.clearInstantiatedContext(rootSQLContext)
} finally {
sc.stop()
+ SQLContext.clearInstantiatedContext()
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 25f2f5caee..b96d50a70b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -34,7 +34,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption()
SQLContext.clearActive()
- originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx))
+ SQLContext.clearInstantiatedContext()
}
override protected def afterAll(): Unit = {