author     Josh Rosen <joshrosen@eecs.berkeley.edu>   2013-01-16 19:15:14 -0800
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>   2013-01-20 13:19:19 -0800
commit     7ed1bf4b485131d58ea6728e7247b79320aca9e6 (patch)
tree       4c9e91c1c997d328bed7c939fdb69f6e8eed516f /python
parent     fe85a075117a79675971aff0cd020bba446c0233 (diff)
Add RDD checkpointing to Python API.
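
For context, here is a minimal sketch (not part of the patch) of how the pieces added by this commit fit together; the master URL, app name, and the /tmp/demo-checkpoints path are arbitrary choices for illustration:

    from pyspark.context import SparkContext

    sc = SparkContext("local[2]", "CheckpointDemo")

    # The directory must not already exist unless useExisting=True;
    # "/tmp/demo-checkpoints" is an arbitrary scratch path.
    sc.setCheckpointDir("/tmp/demo-checkpoints")

    rdd = sc.parallelize([1, 2, 3, 4]).flatMap(lambda x: range(1, x + 1))
    rdd.cache()       # persist first so the checkpoint write avoids recomputation
    rdd.checkpoint()  # mark for checkpointing before running any job

    rdd.count()                     # the first job triggers the actual save
    print(rdd.isCheckpointed())     # True once the post-job write completes
    print(rdd.getCheckpointFile())  # a path under /tmp/demo-checkpoints
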
Diffstat (limited to 'python')
-rw-r--r--  python/epydoc.conf         |  2
-rw-r--r--  python/pyspark/context.py  |  9
-rw-r--r--  python/pyspark/rdd.py      | 34
-rw-r--r--  python/pyspark/tests.py    | 46
-rwxr-xr-x  python/run-tests           |  3
5 files changed, 93 insertions(+), 1 deletion(-)
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 91ac984ba2..45102cd9fe 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -16,4 +16,4 @@ target: docs/
private: no
exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
-         pyspark.java_gateway pyspark.examples pyspark.shell
+         pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 1e2f845f9c..a438b43fdc 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -195,3 +195,12 @@ class SparkContext(object):
         filename = path.split("/")[-1]
         os.environ["PYTHONPATH"] = \
             "%s:%s" % (filename, os.environ["PYTHONPATH"])
+
+    def setCheckpointDir(self, dirName, useExisting=False):
+        """
+        Set the directory under which RDDs are going to be checkpointed. This
+        method will create the directory and will throw an exception if the
+        path already exists (to avoid overwriting existing files), unless
+        `useExisting` is True, in which case an existing directory is reused.
+        """
+        self._jsc.sc().setCheckpointDir(dirName, useExisting)
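
(Not part of the patch.) A quick sketch of the two modes of setCheckpointDir(), with hypothetical paths:

    sc.setCheckpointDir("/tmp/ckpt")                    # creates /tmp/ckpt; raises if it already exists
    sc.setCheckpointDir("/tmp/ckpt", useExisting=True)  # reuses the directory created above
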
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d705f0f9e1..9b676cae4a 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -49,6 +49,40 @@ class RDD(object):
         self._jrdd.cache()
         return self

+    def checkpoint(self):
+        """
+        Mark this RDD for checkpointing. The RDD will be saved to a file
+        inside the checkpoint directory (set using setCheckpointDir()) and
+        all references to its parent RDDs will be removed. This is used to
+        truncate very long lineages. In the current implementation, Spark
+        will save this RDD to a file (using saveAsObjectFile()) after the
+        first job using this RDD is done. Hence, it is strongly recommended
+        that:
+
+        (i) checkpoint() is called before any job has been executed on this
+        RDD, and
+
+        (ii) this RDD has been persisted in memory; otherwise, saving it to
+        a file will require recomputation.
+        """
+        self._jrdd.rdd().checkpoint()
+
+    def isCheckpointed(self):
+        """
+        Return whether this RDD has been checkpointed or not.
+        """
+        return self._jrdd.rdd().isCheckpointed()
+
+    def getCheckpointFile(self):
+        """
+        Get the name of the file to which this RDD was checkpointed.
+        """
+        checkpointFile = self._jrdd.rdd().getCheckpointFile()
+        if checkpointFile.isDefined():
+            return checkpointFile.get()
+        else:
+            return None
+
     # TODO persist(self, storageLevel)

     def map(self, f, preservesPartitioning=False):
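
(Not part of the patch.) To make the ordering recommendations in the checkpoint() docstring concrete, a sketch assuming a SparkContext sc whose checkpoint directory has already been set:

    rdd = sc.parallelize([1, 2, 3]).cache()  # persist first, per recommendation (ii)
    rdd.checkpoint()                         # mark before any job runs, per (i)
    assert rdd.getCheckpointFile() is None   # nothing is written until a job runs
    rdd.count()                              # first job; triggers the checkpoint save
    # once the post-job write completes:
    #   rdd.isCheckpointed()    -> True
    #   rdd.getCheckpointFile() -> a file under the checkpoint directory
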
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
new file mode 100644
index 0000000000..c959d5dec7
--- /dev/null
+++ b/python/pyspark/tests.py
@@ -0,0 +1,46 @@
+"""
+Unit tests for PySpark; additional tests are implemented as doctests in
+individual modules.
+"""
+import atexit
+import os
+import shutil
+from tempfile import NamedTemporaryFile
+import time
+import unittest
+
+from pyspark.context import SparkContext
+
+
+class TestCheckpoint(unittest.TestCase):
+
+    def setUp(self):
+        self.sc = SparkContext('local[4]', 'TestCheckpoint', batchSize=2)
+
+    def tearDown(self):
+        self.sc.stop()
+
+    def test_basic_checkpointing(self):
+        checkpointDir = NamedTemporaryFile(delete=False)
+        os.unlink(checkpointDir.name)
+        self.sc.setCheckpointDir(checkpointDir.name)
+
+        parCollection = self.sc.parallelize([1, 2, 3, 4])
+        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
+
+        self.assertFalse(flatMappedRDD.isCheckpointed())
+        self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+        flatMappedRDD.checkpoint()
+        result = flatMappedRDD.collect()
+        time.sleep(1)  # wait 1 second for the post-job checkpoint write
+        self.assertTrue(flatMappedRDD.isCheckpointed())
+        self.assertEqual(flatMappedRDD.collect(), result)
+        self.assertEqual(checkpointDir.name,
+                         os.path.dirname(flatMappedRDD.getCheckpointFile()))
+
+        atexit.register(lambda: shutil.rmtree(checkpointDir.name))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/run-tests b/python/run-tests
index 32470911f9..ce214e98a8 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -14,6 +14,9 @@ FAILED=$(($?||$FAILED))
$FWDIR/pyspark -m doctest pyspark/accumulators.py
FAILED=$(($?||$FAILED))
+$FWDIR/pyspark -m unittest pyspark.tests
+FAILED=$(($?||$FAILED))
+
if [[ $FAILED != 0 ]]; then
    echo -en "\033[31m" # Red
    echo "Had test failures; see logs."