path: root/python
author     Matei Zaharia <matei@eecs.berkeley.edu>  2013-01-20 17:10:44 -0800
committer  Matei Zaharia <matei@eecs.berkeley.edu>  2013-01-20 17:10:44 -0800
commit     c7b5e5f1ec9b86b91cf955fb16130e1ad2fcd2e9 (patch)
tree       0a9d1751509c4af6235dd921967d508c6f8d130b /python
parent     14360aba89aae6ff8e4568d08a4570907647da6e (diff)
parent     00d70cd6602d5ff2718e319ec04defbdd486237e (diff)
Merge pull request #389 from JoshRosen/python_rdd_checkpointing
Add checkpointing to the Python API
Diffstat (limited to 'python')
-rw-r--r--  python/epydoc.conf          2
-rw-r--r--  python/pyspark/context.py  18
-rw-r--r--  python/pyspark/rdd.py      35
-rw-r--r--  python/pyspark/tests.py    61
-rwxr-xr-x  python/run-tests            3
5 files changed, 116 insertions(+), 3 deletions(-)
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 91ac984ba2..45102cd9fe 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -16,4 +16,4 @@ target: docs/
private: no
exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
- pyspark.java_gateway pyspark.examples pyspark.shell
+ pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 1e2f845f9c..dcbed37270 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -123,6 +123,10 @@ class SparkContext(object):
jrdd = self._jsc.textFile(name, minSplits)
return RDD(jrdd, self)
+ def _checkpointFile(self, name):
+ jrdd = self._jsc.checkpointFile(name)
+ return RDD(jrdd, self)
+
def union(self, rdds):
"""
Build the union of a list of RDDs.
@@ -145,7 +149,7 @@ class SparkContext(object):
def accumulator(self, value, accum_param=None):
"""
Create an C{Accumulator} with the given initial value, using a given
- AccumulatorParam helper object to define how to add values of the data
+ AccumulatorParam helper object to define how to add values of the data
type if provided. Default AccumulatorParams are used for integers and
floating-point numbers if you do not provide one. For other types, the
AccumulatorParam must implement two methods:
@@ -195,3 +199,15 @@ class SparkContext(object):
filename = path.split("/")[-1]
os.environ["PYTHONPATH"] = \
"%s:%s" % (filename, os.environ["PYTHONPATH"])
+
+ def setCheckpointDir(self, dirName, useExisting=False):
+ """
+ Set the directory under which RDDs are going to be checkpointed. The
+ directory must be an HDFS path if running on a cluster.
+
+ If the directory does not exist, it will be created. If the directory
+ exists and C{useExisting} is set to true, then the existing directory
+ will be used. Otherwise an exception will be thrown to prevent
+ accidental overriding of checkpoint files in the existing directory.
+ """
+ self._jsc.sc().setCheckpointDir(dirName, useExisting)
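For reference, a minimal sketch of how the new setCheckpointDir() call could be used from a driver program; the master string, application name, and checkpoint path below are illustrative and not part of this patch:

    from pyspark.context import SparkContext

    sc = SparkContext("local[2]", "CheckpointDemo")
    # On a cluster the checkpoint directory must be an HDFS path; a local
    # path is fine when running in local mode. Pass useExisting=True to
    # reuse a directory that already exists.
    sc.setCheckpointDir("/tmp/spark-checkpoints")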
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index b58bf24e3e..d53355a8f1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -32,6 +32,7 @@ class RDD(object):
def __init__(self, jrdd, ctx):
self._jrdd = jrdd
self.is_cached = False
+ self.is_checkpointed = False
self.ctx = ctx
self._partitionFunc = None
@@ -50,6 +51,34 @@ class RDD(object):
self._jrdd.cache()
return self
+ def checkpoint(self):
+ """
+ Mark this RDD for checkpointing. It will be saved to a file inside the
+ checkpoint directory set with L{SparkContext.setCheckpointDir()} and
+ all references to its parent RDDs will be removed. This function must
+ be called before any job has been executed on this RDD. It is strongly
+ recommended that this RDD be persisted in memory; otherwise, saving it
+ to a file will require recomputation.
+ """
+ self.is_checkpointed = True
+ self._jrdd.rdd().checkpoint()
+
+ def isCheckpointed(self):
+ """
+ Return whether this RDD has been checkpointed or not
+ """
+ return self._jrdd.rdd().isCheckpointed()
+
+ def getCheckpointFile(self):
+ """
+ Gets the name of the file to which this RDD was checkpointed
+ """
+ checkpointFile = self._jrdd.rdd().getCheckpointFile()
+ if checkpointFile.isDefined():
+ return checkpointFile.get()
+ else:
+ return None
+
# TODO persist(self, storageLevel)
def map(self, f, preservesPartitioning=False):
@@ -667,7 +696,7 @@ class PipelinedRDD(RDD):
20
"""
def __init__(self, prev, func, preservesPartitioning=False):
- if isinstance(prev, PipelinedRDD) and not prev.is_cached:
+ if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
prev_func = prev.func
def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator))
@@ -680,6 +709,7 @@ class PipelinedRDD(RDD):
self.preservesPartitioning = preservesPartitioning
self._prev_jrdd = prev._jrdd
self.is_cached = False
+ self.is_checkpointed = False
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
@@ -712,6 +742,9 @@ class PipelinedRDD(RDD):
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
+ def _is_pipelinable(self):
+ return not (self.is_cached or self.is_checkpointed)
+
def _test():
import doctest
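Taken together, the rdd.py additions support a workflow along these lines; a hedged sketch with illustrative data, assuming a SparkContext whose checkpoint directory has already been set as shown above:

    rdd = sc.parallelize(range(100)).map(lambda x: x * x)
    rdd.cache()        # recommended, so checkpointing does not recompute the RDD
    rdd.checkpoint()   # must be called before any job has run on this RDD
    rdd.count()        # the first action materializes the checkpoint

    print rdd.isCheckpointed()      # True once the checkpoint has been written
    print rdd.getCheckpointFile()   # path inside the checkpoint directory,
                                    # or None if no checkpoint exists yet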
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
new file mode 100644
index 0000000000..b0a403b580
--- /dev/null
+++ b/python/pyspark/tests.py
@@ -0,0 +1,61 @@
+"""
+Unit tests for PySpark; additional tests are implemented as doctests in
+individual modules.
+"""
+import os
+import shutil
+from tempfile import NamedTemporaryFile
+import time
+import unittest
+
+from pyspark.context import SparkContext
+
+
+class TestCheckpoint(unittest.TestCase):
+
+ def setUp(self):
+ self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
+ self.checkpointDir = NamedTemporaryFile(delete=False)
+ os.unlink(self.checkpointDir.name)
+ self.sc.setCheckpointDir(self.checkpointDir.name)
+
+ def tearDown(self):
+ self.sc.stop()
+ # To avoid Akka rebinding to the same port, since it doesn't unbind
+ # immediately on shutdown
+ self.sc.jvm.System.clearProperty("spark.master.port")
+ shutil.rmtree(self.checkpointDir.name)
+
+ def test_basic_checkpointing(self):
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+ flatMappedRDD.checkpoint()
+ result = flatMappedRDD.collect()
+ time.sleep(1) # 1 second
+ self.assertTrue(flatMappedRDD.isCheckpointed())
+ self.assertEqual(flatMappedRDD.collect(), result)
+ self.assertEqual(self.checkpointDir.name,
+ os.path.dirname(flatMappedRDD.getCheckpointFile()))
+
+ def test_checkpoint_and_restore(self):
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: [x])
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+ flatMappedRDD.checkpoint()
+ flatMappedRDD.count() # forces a checkpoint to be computed
+ time.sleep(1) # 1 second
+
+ self.assertIsNotNone(flatMappedRDD.getCheckpointFile())
+ recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+ self.assertEquals([1, 2, 3, 4], recovered.collect())
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/run-tests b/python/run-tests
index 32470911f9..ce214e98a8 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -14,6 +14,9 @@ FAILED=$(($?||$FAILED))
$FWDIR/pyspark -m doctest pyspark/accumulators.py
FAILED=$(($?||$FAILED))
+$FWDIR/pyspark -m unittest pyspark.tests
+FAILED=$(($?||$FAILED))
+
if [[ $FAILED != 0 ]]; then
echo -en "\033[31m" # Red
echo "Had test failures; see logs."