diff options
author | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2013-01-16 19:15:14 -0800 |
---|---|---|
committer | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2013-01-20 13:19:19 -0800 |
commit | 7ed1bf4b485131d58ea6728e7247b79320aca9e6 (patch) | |
tree | 4c9e91c1c997d328bed7c939fdb69f6e8eed516f /python/pyspark/tests.py | |
parent | fe85a075117a79675971aff0cd020bba446c0233 (diff) | |
download | spark-7ed1bf4b485131d58ea6728e7247b79320aca9e6.tar.gz spark-7ed1bf4b485131d58ea6728e7247b79320aca9e6.tar.bz2 spark-7ed1bf4b485131d58ea6728e7247b79320aca9e6.zip |
Add RDD checkpointing to Python API.
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r-- | python/pyspark/tests.py | 46 |
1 files changed, 46 insertions, 0 deletions
# diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
# (new file mode 100644, index 0000000000..c959d5dec7, @@ -0,0 +1,46 @@)
"""
Unit tests for PySpark; additional tests are implemented as doctests in
individual modules.
"""
import atexit
import os
import shutil
from tempfile import NamedTemporaryFile
import time
import unittest

from pyspark.context import SparkContext


class TestCheckpoint(unittest.TestCase):
    """Exercise RDD checkpointing through the Python API."""

    def setUp(self):
        """Create a fresh SparkContext for each test."""
        self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)

    def tearDown(self):
        """Stop the SparkContext created in setUp."""
        self.sc.stop()

    def test_basic_checkpointing(self):
        """Checkpoint a flat-mapped RDD and verify its state and contents.

        Before ``checkpoint()`` the RDD must report that it is not
        checkpointed and has no checkpoint file.  After checkpointing and
        collecting, it must report being checkpointed, return the same
        data, and store its checkpoint file directly inside the
        configured checkpoint directory.
        """
        # NamedTemporaryFile is used only to obtain a unique path: the file
        # is removed right away and the path is handed to Spark as the
        # checkpoint directory.  Close the handle before unlinking so the
        # descriptor is not leaked (and so the unlink works on platforms
        # that refuse to delete open files).
        checkpointDir = NamedTemporaryFile(delete=False)
        checkpointDir.close()
        os.unlink(checkpointDir.name)

        # Register the cleanup *before* anything can fail; registering it
        # as the last statement would leak the directory whenever an
        # assertion or a Spark call raises.  ignore_errors covers the case
        # where the test fails before the directory is ever created.
        atexit.register(shutil.rmtree, checkpointDir.name,
                        ignore_errors=True)

        self.sc.setCheckpointDir(checkpointDir.name)

        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertIsNone(flatMappedRDD.getCheckpointFile())

        flatMappedRDD.checkpoint()
        result = flatMappedRDD.collect()
        # NOTE(review): presumably this pause gives the checkpoint time to
        # be written before its state is queried -- confirm.
        time.sleep(1)  # 1 second
        self.assertTrue(flatMappedRDD.isCheckpointed())
        self.assertEqual(flatMappedRDD.collect(), result)
        self.assertEqual(checkpointDir.name,
                         os.path.dirname(flatMappedRDD.getCheckpointFile()))


if __name__ == "__main__":
    unittest.main()