diff options
author | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2013-01-20 15:38:11 -0800 |
---|---|---|
committer | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2013-01-20 15:38:11 -0800 |
commit | 00d70cd6602d5ff2718e319ec04defbdd486237e (patch) | |
tree | ca2146992f1fcc8142d53bb4b7b1c2778e78b0c5 /python/pyspark | |
parent | 5b6ea9e9a04994553d0319c541ca356e2e3064a7 (diff) | |
download | spark-00d70cd6602d5ff2718e319ec04defbdd486237e.tar.gz spark-00d70cd6602d5ff2718e319ec04defbdd486237e.tar.bz2 spark-00d70cd6602d5ff2718e319ec04defbdd486237e.zip |
Clean up setup code in PySpark checkpointing tests
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/rdd.py | 3 | ||||
-rw-r--r-- | python/pyspark/tests.py | 19 |
2 files changed, 6 insertions, 16 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7b6ab956ee..097cdb13b4 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -691,7 +691,7 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and prev._is_pipelinable: + if isinstance(prev, PipelinedRDD) and prev._is_pipelinable(): prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) @@ -737,7 +737,6 @@ class PipelinedRDD(RDD): self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val - @property def _is_pipelinable(self): return not (self.is_cached or self.is_checkpointed) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 83283fca4f..b0a403b580 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2,7 +2,6 @@ Unit tests for PySpark; additional tests are implemented as doctests in individual modules. """ -import atexit import os import shutil from tempfile import NamedTemporaryFile @@ -16,18 +15,18 @@ class TestCheckpoint(unittest.TestCase): def setUp(self): self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2) + self.checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(self.checkpointDir.name) + self.sc.setCheckpointDir(self.checkpointDir.name) def tearDown(self): self.sc.stop() # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown self.sc.jvm.System.clearProperty("spark.master.port") + shutil.rmtree(self.checkpointDir.name) def test_basic_checkpointing(self): - checkpointDir = NamedTemporaryFile(delete=False) - os.unlink(checkpointDir.name) - self.sc.setCheckpointDir(checkpointDir.name) - parCollection = self.sc.parallelize([1, 2, 3, 4]) flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) @@ -39,16 +38,10 @@ class TestCheckpoint(unittest.TestCase): time.sleep(1) # 1 second self.assertTrue(flatMappedRDD.isCheckpointed()) self.assertEqual(flatMappedRDD.collect(), result) - self.assertEqual(checkpointDir.name, + self.assertEqual(self.checkpointDir.name, os.path.dirname(flatMappedRDD.getCheckpointFile())) - atexit.register(lambda: shutil.rmtree(checkpointDir.name)) - def test_checkpoint_and_restore(self): - checkpointDir = NamedTemporaryFile(delete=False) - os.unlink(checkpointDir.name) - self.sc.setCheckpointDir(checkpointDir.name) - parCollection = self.sc.parallelize([1, 2, 3, 4]) flatMappedRDD = parCollection.flatMap(lambda x: [x]) @@ -63,8 +56,6 @@ class TestCheckpoint(unittest.TestCase): recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) self.assertEquals([1, 2, 3, 4], recovered.collect()) - atexit.register(lambda: shutil.rmtree(checkpointDir.name)) - if __name__ == "__main__": unittest.main() |