aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@eecs.berkeley.edu>2013-01-16 19:15:14 -0800
committerJosh Rosen <joshrosen@eecs.berkeley.edu>2013-01-20 13:19:19 -0800
commit7ed1bf4b485131d58ea6728e7247b79320aca9e6 (patch)
tree4c9e91c1c997d328bed7c939fdb69f6e8eed516f /python/pyspark/tests.py
parentfe85a075117a79675971aff0cd020bba446c0233 (diff)
downloadspark-7ed1bf4b485131d58ea6728e7247b79320aca9e6.tar.gz
spark-7ed1bf4b485131d58ea6728e7247b79320aca9e6.tar.bz2
spark-7ed1bf4b485131d58ea6728e7247b79320aca9e6.zip
Add RDD checkpointing to Python API.
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--python/pyspark/tests.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
new file mode 100644
index 0000000000..c959d5dec7
--- /dev/null
+++ b/python/pyspark/tests.py
@@ -0,0 +1,46 @@
+"""
+Unit tests for PySpark; additional tests are implemented as doctests in
+individual modules.
+"""
+import atexit
+import os
+import shutil
+from tempfile import NamedTemporaryFile
+import time
+import unittest
+
+from pyspark.context import SparkContext
+
+
+class TestCheckpoint(unittest.TestCase):
+
+ def setUp(self):
+ self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
+
+ def tearDown(self):
+ self.sc.stop()
+
+ def test_basic_checkpointing(self):
+ checkpointDir = NamedTemporaryFile(delete=False)
+ os.unlink(checkpointDir.name)
+ self.sc.setCheckpointDir(checkpointDir.name)
+
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+ flatMappedRDD.checkpoint()
+ result = flatMappedRDD.collect()
+ time.sleep(1) # 1 second
+ self.assertTrue(flatMappedRDD.isCheckpointed())
+ self.assertEqual(flatMappedRDD.collect(), result)
+ self.assertEqual(checkpointDir.name,
+ os.path.dirname(flatMappedRDD.getCheckpointFile()))
+
+ atexit.register(lambda: shutil.rmtree(checkpointDir.name))
+
+
+if __name__ == "__main__":
+ unittest.main()