diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-01-20 17:10:44 -0800 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-01-20 17:10:44 -0800 |
commit | c7b5e5f1ec9b86b91cf955fb16130e1ad2fcd2e9 (patch) | |
tree | 0a9d1751509c4af6235dd921967d508c6f8d130b | |
parent | 14360aba89aae6ff8e4568d08a4570907647da6e (diff) | |
parent | 00d70cd6602d5ff2718e319ec04defbdd486237e (diff) | |
download | spark-c7b5e5f1ec9b86b91cf955fb16130e1ad2fcd2e9.tar.gz spark-c7b5e5f1ec9b86b91cf955fb16130e1ad2fcd2e9.tar.bz2 spark-c7b5e5f1ec9b86b91cf955fb16130e1ad2fcd2e9.zip |
Merge pull request #389 from JoshRosen/python_rdd_checkpointing
Add checkpointing to the Python API
-rw-r--r-- | core/src/main/scala/spark/api/java/JavaRDDLike.scala | 17 | ||||
-rw-r--r-- | core/src/main/scala/spark/api/java/JavaSparkContext.scala | 17 | ||||
-rw-r--r-- | core/src/main/scala/spark/api/python/PythonRDD.scala | 3 | ||||
-rw-r--r-- | python/epydoc.conf | 2 | ||||
-rw-r--r-- | python/pyspark/context.py | 18 | ||||
-rw-r--r-- | python/pyspark/rdd.py | 35 | ||||
-rw-r--r-- | python/pyspark/tests.py | 61 | ||||
-rwxr-xr-x | python/run-tests | 3 |
8 files changed, 132 insertions, 24 deletions
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 087270e46d..b3698ffa44 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -307,16 +307,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] JavaPairRDD.fromRDD(rdd.keyBy(f)) } - - /** - * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` - * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. - * This is used to truncate very long lineages. In the current implementation, Spark will save - * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. - * Hence, it is strongly recommended to use checkpoint() on RDDs when - * (i) checkpoint() is called before the any job has been executed on this RDD. - * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will - * require recomputation. + + /** + * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. */ def checkpoint() = rdd.checkpoint() diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index fa2f14113d..14699961ad 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -357,20 +357,21 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists and useExisting is set to true, then the + * exisiting directory will be used. Otherwise an exception will be thrown to + * prevent accidental overriding of checkpoint files in the existing directory. */ def setCheckpointDir(dir: String, useExisting: Boolean) { sc.setCheckpointDir(dir, useExisting) } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists, an exception will be thrown to prevent accidental + * overriding of checkpoint files. */ def setCheckpointDir(dir: String) { sc.setCheckpointDir(dir) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index e4c0530241..5526406a20 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -135,8 +135,6 @@ private[spark] class PythonRDD[T: ClassManifest]( } } - override def checkpoint() { } - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } @@ -152,7 +150,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends case Seq(a, b) => (a, b) case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } - override def checkpoint() { } val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } diff --git a/python/epydoc.conf b/python/epydoc.conf index 91ac984ba2..45102cd9fe 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -16,4 +16,4 @@ target: docs/ private: no exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers - pyspark.java_gateway pyspark.examples pyspark.shell + pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1e2f845f9c..dcbed37270 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -123,6 +123,10 @@ class SparkContext(object): jrdd = self._jsc.textFile(name, minSplits) return RDD(jrdd, self) + def _checkpointFile(self, name): + jrdd = self._jsc.checkpointFile(name) + return RDD(jrdd, self) + def union(self, rdds): """ Build the union of a list of RDDs. @@ -145,7 +149,7 @@ class SparkContext(object): def accumulator(self, value, accum_param=None): """ Create an C{Accumulator} with the given initial value, using a given - AccumulatorParam helper object to define how to add values of the data + AccumulatorParam helper object to define how to add values of the data type if provided. Default AccumulatorParams are used for integers and floating-point numbers if you do not provide one. For other types, the AccumulatorParam must implement two methods: @@ -195,3 +199,15 @@ class SparkContext(object): filename = path.split("/")[-1] os.environ["PYTHONPATH"] = \ "%s:%s" % (filename, os.environ["PYTHONPATH"]) + + def setCheckpointDir(self, dirName, useExisting=False): + """ + Set the directory under which RDDs are going to be checkpointed. The + directory must be a HDFS path if running on a cluster. + + If the directory does not exist, it will be created. If the directory + exists and C{useExisting} is set to true, then the exisiting directory + will be used. Otherwise an exception will be thrown to prevent + accidental overriding of checkpoint files in the existing directory. + """ + self._jsc.sc().setCheckpointDir(dirName, useExisting) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b58bf24e3e..d53355a8f1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -32,6 +32,7 @@ class RDD(object): def __init__(self, jrdd, ctx): self._jrdd = jrdd self.is_cached = False + self.is_checkpointed = False self.ctx = ctx self._partitionFunc = None @@ -50,6 +51,34 @@ class RDD(object): self._jrdd.cache() return self + def checkpoint(self): + """ + Mark this RDD for checkpointing. It will be saved to a file inside the + checkpoint directory set with L{SparkContext.setCheckpointDir()} and + all references to its parent RDDs will be removed. This function must + be called before any job has been executed on this RDD. It is strongly + recommended that this RDD is persisted in memory, otherwise saving it + on a file will require recomputation. + """ + self.is_checkpointed = True + self._jrdd.rdd().checkpoint() + + def isCheckpointed(self): + """ + Return whether this RDD has been checkpointed or not + """ + return self._jrdd.rdd().isCheckpointed() + + def getCheckpointFile(self): + """ + Gets the name of the file to which this RDD was checkpointed + """ + checkpointFile = self._jrdd.rdd().getCheckpointFile() + if checkpointFile.isDefined(): + return checkpointFile.get() + else: + return None + # TODO persist(self, storageLevel) def map(self, f, preservesPartitioning=False): @@ -667,7 +696,7 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and not prev.is_cached: + if isinstance(prev, PipelinedRDD) and prev._is_pipelinable(): prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) @@ -680,6 +709,7 @@ class PipelinedRDD(RDD): self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd self.is_cached = False + self.is_checkpointed = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None @@ -712,6 +742,9 @@ class PipelinedRDD(RDD): self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val + def _is_pipelinable(self): + return not (self.is_cached or self.is_checkpointed) + def _test(): import doctest diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py new file mode 100644 index 0000000000..b0a403b580 --- /dev/null +++ b/python/pyspark/tests.py @@ -0,0 +1,61 @@ +""" +Unit tests for PySpark; additional tests are implemented as doctests in +individual modules. +""" +import os +import shutil +from tempfile import NamedTemporaryFile +import time +import unittest + +from pyspark.context import SparkContext + + +class TestCheckpoint(unittest.TestCase): + + def setUp(self): + self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2) + self.checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(self.checkpointDir.name) + self.sc.setCheckpointDir(self.checkpointDir.name) + + def tearDown(self): + self.sc.stop() + # To avoid Akka rebinding to the same port, since it doesn't unbind + # immediately on shutdown + self.sc.jvm.System.clearProperty("spark.master.port") + shutil.rmtree(self.checkpointDir.name) + + def test_basic_checkpointing(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertIsNone(flatMappedRDD.getCheckpointFile()) + + flatMappedRDD.checkpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + self.assertEqual(self.checkpointDir.name, + os.path.dirname(flatMappedRDD.getCheckpointFile())) + + def test_checkpoint_and_restore(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: [x]) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertIsNone(flatMappedRDD.getCheckpointFile()) + + flatMappedRDD.checkpoint() + flatMappedRDD.count() # forces a checkpoint to be computed + time.sleep(1) # 1 second + + self.assertIsNotNone(flatMappedRDD.getCheckpointFile()) + recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) + self.assertEquals([1, 2, 3, 4], recovered.collect()) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/run-tests b/python/run-tests index 32470911f9..ce214e98a8 100755 --- a/python/run-tests +++ b/python/run-tests @@ -14,6 +14,9 @@ FAILED=$(($?||$FAILED)) $FWDIR/pyspark -m doctest pyspark/accumulators.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark -m unittest pyspark.tests +FAILED=$(($?||$FAILED)) + if [[ $FAILED != 0 ]]; then echo -en "\033[31m" # Red echo "Had test failures; see logs." |