about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
author: Josh Rosen <joshrosen@eecs.berkeley.edu> 2013-01-20 15:38:11 -0800
committer: Josh Rosen <joshrosen@eecs.berkeley.edu> 2013-01-20 15:38:11 -0800
commit: 00d70cd6602d5ff2718e319ec04defbdd486237e (patch)
tree: ca2146992f1fcc8142d53bb4b7b1c2778e78b0c5 /python/pyspark/tests.py
parent: 5b6ea9e9a04994553d0319c541ca356e2e3064a7 (diff)
download: spark-00d70cd6602d5ff2718e319ec04defbdd486237e.tar.gz
spark-00d70cd6602d5ff2718e319ec04defbdd486237e.tar.bz2
spark-00d70cd6602d5ff2718e319ec04defbdd486237e.zip
Clean up setup code in PySpark checkpointing tests
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r-- python/pyspark/tests.py | 19
1 files changed, 5 insertions, 14 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 83283fca4f..b0a403b580 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -2,7 +2,6 @@
Unit tests for PySpark; additional tests are implemented as doctests in
individual modules.
"""
-import atexit
import os
import shutil
from tempfile import NamedTemporaryFile
@@ -16,18 +15,18 @@ class TestCheckpoint(unittest.TestCase):
def setUp(self):
self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
+ self.checkpointDir = NamedTemporaryFile(delete=False)
+ os.unlink(self.checkpointDir.name)
+ self.sc.setCheckpointDir(self.checkpointDir.name)
def tearDown(self):
self.sc.stop()
# To avoid Akka rebinding to the same port, since it doesn't unbind
# immediately on shutdown
self.sc.jvm.System.clearProperty("spark.master.port")
+ shutil.rmtree(self.checkpointDir.name)
def test_basic_checkpointing(self):
- checkpointDir = NamedTemporaryFile(delete=False)
- os.unlink(checkpointDir.name)
- self.sc.setCheckpointDir(checkpointDir.name)
-
parCollection = self.sc.parallelize([1, 2, 3, 4])
flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
@@ -39,16 +38,10 @@ class TestCheckpoint(unittest.TestCase):
time.sleep(1) # 1 second
self.assertTrue(flatMappedRDD.isCheckpointed())
self.assertEqual(flatMappedRDD.collect(), result)
- self.assertEqual(checkpointDir.name,
+ self.assertEqual(self.checkpointDir.name,
os.path.dirname(flatMappedRDD.getCheckpointFile()))
- atexit.register(lambda: shutil.rmtree(checkpointDir.name))
-
def test_checkpoint_and_restore(self):
- checkpointDir = NamedTemporaryFile(delete=False)
- os.unlink(checkpointDir.name)
- self.sc.setCheckpointDir(checkpointDir.name)
-
parCollection = self.sc.parallelize([1, 2, 3, 4])
flatMappedRDD = parCollection.flatMap(lambda x: [x])
@@ -63,8 +56,6 @@ class TestCheckpoint(unittest.TestCase):
recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
self.assertEquals([1, 2, 3, 4], recovered.collect())
- atexit.register(lambda: shutil.rmtree(checkpointDir.name))
-
if __name__ == "__main__":
unittest.main()