author     Liang-Chi Hsieh <viirya@gmail.com>      2017-01-12 20:53:31 +0800
committer  Wenchen Fan <wenchen@databricks.com>    2017-01-12 20:53:31 +0800
commit     c6c37b8af714c8ddc8c77ac943a379f703558f27 (patch)
tree       ca94dfda92a9242c63572fff433ad8e60da1f7da /python
parent     871d266649ddfed38c64dfda7158d8bb58d4b979 (diff)
[SPARK-19055][SQL][PYSPARK] Fix SparkSession initialization when SparkContext is stopped
## What changes were proposed in this pull request?

During SparkSession initialization, we store the created SparkSession instance in the class variable _instantiatedContext, so that a later SparkSession.builder.getOrCreate() call can retrieve the existing SparkSession instance. However, when the active SparkContext is stopped and a new SparkContext is created, the existing SparkSession is still associated with the stopped SparkContext, so operations on that SparkSession will fail. We need to detect this case in SparkSession and renew the class variable _instantiatedContext if needed.

## How was this patch tested?

New test added in PySpark.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #16454 from viirya/fix-pyspark-sparksession.
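For context, here is a minimal sketch of the scenario this change fixes. It is not part of the patch; the master URL, app name, and sample data are illustrative, and it assumes a local PySpark installation.

    # Illustrative only: reproduces the stale-session scenario the patch addresses.
    from pyspark import SparkContext
    from pyspark.sql import SparkSession

    sc = SparkContext("local[2]", "demo")            # first SparkContext
    spark = SparkSession.builder.getOrCreate()       # cached in the class variable

    sc.stop()                                        # cached session now points at a stopped context
    sc2 = SparkContext("local[2]", "demo")           # create a fresh SparkContext

    # Without the fix, getOrCreate() returns the stale session bound to the stopped
    # context and DataFrame operations fail; with the fix, the stale session is renewed.
    spark2 = SparkSession.builder.getOrCreate()
    spark2.createDataFrame([(1, "a")], ["x", "y"]).collect()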
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/session.py | 16
-rw-r--r--  python/pyspark/sql/tests.py   | 23
2 files changed, 33 insertions(+), 6 deletions(-)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 1e40b9c39f..9f4772eec9 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -161,8 +161,8 @@ class SparkSession(object):
with self._lock:
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
- session = SparkSession._instantiatedContext
- if session is None:
+ session = SparkSession._instantiatedSession
+ if session is None or session._sc._jsc is None:
sparkConf = SparkConf()
for key, value in self._options.items():
sparkConf.set(key, value)
@@ -183,7 +183,7 @@ class SparkSession(object):
builder = Builder()
- _instantiatedContext = None
+ _instantiatedSession = None
@ignore_unicode_prefix
def __init__(self, sparkContext, jsparkSession=None):
@@ -214,8 +214,12 @@ class SparkSession(object):
self._wrapped = SQLContext(self._sc, self, self._jwrapped)
_monkey_patch_RDD(self)
install_exception_handler()
- if SparkSession._instantiatedContext is None:
- SparkSession._instantiatedContext = self
+ # If we had an instantiated SparkSession attached to a SparkContext
+ # that has since been stopped, we need to renew the instantiated SparkSession.
+ # Otherwise, Builder.getOrCreate would return an invalid SparkSession.
+ if SparkSession._instantiatedSession is None \
+ or SparkSession._instantiatedSession._sc._jsc is None:
+ SparkSession._instantiatedSession = self
@since(2.0)
def newSession(self):
@@ -595,7 +599,7 @@ class SparkSession(object):
"""Stop the underlying :class:`SparkContext`.
"""
self._sc.stop()
- SparkSession._instantiatedContext = None
+ SparkSession._instantiatedSession = None
@since(2.0)
def __enter__(self):
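As an aside, the staleness check added above (session._sc._jsc is None) relies on PySpark clearing the JVM-side handle when a context is stopped. A small sketch, with an illustrative master URL and app name:

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "jsc-check-demo")
    assert sc._jsc is not None     # a live context holds a JVM-side handle

    sc.stop()                      # stop() clears the handle
    assert sc._jsc is None         # this is what the new check detects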
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 18fd68ec5e..d1782857e6 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -46,6 +46,7 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
+from pyspark import SparkContext
from pyspark.sql import SparkSession, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type
@@ -1886,6 +1887,28 @@ class HiveSparkSubmitTests(SparkSubmitTests):
self.assertTrue(os.path.exists(metastore_path))
+class SQLTests2(ReusedPySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.spark = SparkSession(cls.sc)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ cls.spark.stop()
+
+ # We can't include this test in SQLTests because it stops the class's SparkContext
+ # and would cause other tests to fail.
+ def test_sparksession_with_stopped_sparkcontext(self):
+ self.sc.stop()
+ sc = SparkContext('local[4]', self.sc.appName)
+ spark = SparkSession.builder.getOrCreate()
+ df = spark.createDataFrame([(1, 2)], ["c", "c"])
+ df.collect()
+
+
class HiveContextSQLTests(ReusedPySparkTestCase):
@classmethod