diff options
author | Matthew Farrellee <matt@redhat.com> | 2014-09-09 18:54:54 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-09-09 18:54:54 -0700 |
commit | 25b5b867d5e18bac1c5bcdc6f8c63d97858194c7 (patch) | |
tree | c89d169c2b6461d7d01eff0a6cbf61c3235acef0 | |
parent | c110614b33a690a3db6ccb1a920fb6a3795aa5a0 (diff) | |
download | spark-25b5b867d5e18bac1c5bcdc6f8c63d97858194c7.tar.gz spark-25b5b867d5e18bac1c5bcdc6f8c63d97858194c7.tar.bz2 spark-25b5b867d5e18bac1c5bcdc6f8c63d97858194c7.zip |
[SPARK-3458] enable python "with" statements for SparkContext
allow for best practice code,
```
try:
sc = SparkContext()
app(sc)
finally:
sc.stop()
```
to be written using a "with" statement,
```
with SparkContext() as sc:
app(sc)
```
Author: Matthew Farrellee <matt@redhat.com>
Closes #2335 from mattf/SPARK-3458 and squashes the following commits:
5b4e37c [Matthew Farrellee] [SPARK-3458] enable python "with" statements for SparkContext
-rw-r--r-- | python/pyspark/context.py | 14 | ||||
-rw-r--r-- | python/pyspark/tests.py | 29 |
2 files changed, 43 insertions, 0 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 5a30431568..84bc0a3b7c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -232,6 +232,20 @@ class SparkContext(object):
         else:
             SparkContext._active_spark_context = instance

+    def __enter__(self):
+        """
+        Enable 'with SparkContext(...) as sc: app(sc)' syntax.
+        """
+        return self
+
+    def __exit__(self, type, value, trace):
+        """
+        Enable 'with SparkContext(...) as sc: app' syntax.
+
+        Specifically stop the context on exit of the with block.
+        """
+        self.stop()
+
     @classmethod
     def setSystemProperty(cls, key, value):
         """
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 0bd2a9e6c5..bb84ebe72c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1254,6 +1254,35 @@ class TestSparkSubmit(unittest.TestCase):
         self.assertIn("[2, 4, 6]", out)


+class ContextStopTests(unittest.TestCase):
+
+    def test_stop(self):
+        sc = SparkContext()
+        self.assertNotEqual(SparkContext._active_spark_context, None)
+        sc.stop()
+        self.assertEqual(SparkContext._active_spark_context, None)
+
+    def test_with(self):
+        with SparkContext() as sc:
+            self.assertNotEqual(SparkContext._active_spark_context, None)
+        self.assertEqual(SparkContext._active_spark_context, None)
+
+    def test_with_exception(self):
+        try:
+            with SparkContext() as sc:
+                self.assertNotEqual(SparkContext._active_spark_context, None)
+                raise Exception()
+        except:
+            pass
+        self.assertEqual(SparkContext._active_spark_context, None)
+
+    def test_with_stop(self):
+        with SparkContext() as sc:
+            self.assertNotEqual(SparkContext._active_spark_context, None)
+            sc.stop()
+        self.assertEqual(SparkContext._active_spark_context, None)
+
+
 @unittest.skipIf(not _have_scipy, "SciPy not installed")
 class SciPyTests(PySparkTestCase):