about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
author: Josh Rosen <joshrosen@eecs.berkeley.edu> 2013-01-20 13:59:45 -0800
committer: Josh Rosen <joshrosen@eecs.berkeley.edu> 2013-01-20 13:59:45 -0800
commit: d0ba80dc727d00b2b7627dcefd2c77009af55f7d (patch)
tree: eadf5f88feb468179527459a1f7316b3ca5ac2cf /python/pyspark/tests.py
parent: 7ed1bf4b485131d58ea6728e7247b79320aca9e6 (diff)
download: spark-d0ba80dc727d00b2b7627dcefd2c77009af55f7d.tar.gz
spark-d0ba80dc727d00b2b7627dcefd2c77009af55f7d.tar.bz2
spark-d0ba80dc727d00b2b7627dcefd2c77009af55f7d.zip
Add checkpointFile() and more tests to PySpark.
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r-- python/pyspark/tests.py 24
1 file changed, 24 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index c959d5dec7..83283fca4f 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -19,6 +19,9 @@ class TestCheckpoint(unittest.TestCase):
def tearDown(self):
self.sc.stop()
+ # To avoid Akka rebinding to the same port, since it doesn't unbind
+ # immediately on shutdown
+ self.sc.jvm.System.clearProperty("spark.master.port")
def test_basic_checkpointing(self):
checkpointDir = NamedTemporaryFile(delete=False)
@@ -41,6 +44,27 @@ class TestCheckpoint(unittest.TestCase):
atexit.register(lambda: shutil.rmtree(checkpointDir.name))
+ def test_checkpoint_and_restore(self):
+ checkpointDir = NamedTemporaryFile(delete=False)
+ os.unlink(checkpointDir.name)
+ self.sc.setCheckpointDir(checkpointDir.name)
+
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: [x])
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+ flatMappedRDD.checkpoint()
+ flatMappedRDD.count() # forces a checkpoint to be computed
+ time.sleep(1) # 1 second
+
+ self.assertIsNotNone(flatMappedRDD.getCheckpointFile())
+ recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+ self.assertEquals([1, 2, 3, 4], recovered.collect())
+
+ atexit.register(lambda: shutil.rmtree(checkpointDir.name))
+
if __name__ == "__main__":
unittest.main()