aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatei Zaharia <matei@eecs.berkeley.edu>2013-01-20 17:10:44 -0800
committerMatei Zaharia <matei@eecs.berkeley.edu>2013-01-20 17:10:44 -0800
commitc7b5e5f1ec9b86b91cf955fb16130e1ad2fcd2e9 (patch)
tree0a9d1751509c4af6235dd921967d508c6f8d130b
parent14360aba89aae6ff8e4568d08a4570907647da6e (diff)
parent00d70cd6602d5ff2718e319ec04defbdd486237e (diff)
downloadspark-c7b5e5f1ec9b86b91cf955fb16130e1ad2fcd2e9.tar.gz
spark-c7b5e5f1ec9b86b91cf955fb16130e1ad2fcd2e9.tar.bz2
spark-c7b5e5f1ec9b86b91cf955fb16130e1ad2fcd2e9.zip
Merge pull request #389 from JoshRosen/python_rdd_checkpointing
Add checkpointing to the Python API
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDDLike.scala17
-rw-r--r--core/src/main/scala/spark/api/java/JavaSparkContext.scala17
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala3
-rw-r--r--python/epydoc.conf2
-rw-r--r--python/pyspark/context.py18
-rw-r--r--python/pyspark/rdd.py35
-rw-r--r--python/pyspark/tests.py61
-rwxr-xr-xpython/run-tests3
8 files changed, 132 insertions, 24 deletions
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 087270e46d..b3698ffa44 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -307,16 +307,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
JavaPairRDD.fromRDD(rdd.keyBy(f))
}
-
- /**
- * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir`
- * (set using setCheckpointDir()) and all references to its parent RDDs will be removed.
- * This is used to truncate very long lineages. In the current implementation, Spark will save
- * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done.
- * Hence, it is strongly recommended to use checkpoint() on RDDs when
- * (i) checkpoint() is called before the any job has been executed on this RDD.
- * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will
- * require recomputation.
+
+ /**
+ * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
+ * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * RDDs will be removed. This function must be called before any job has been
+ * executed on this RDD. It is strongly recommended that this RDD is persisted in
+ * memory, otherwise saving it on a file will require recomputation.
*/
def checkpoint() = rdd.checkpoint()
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index fa2f14113d..14699961ad 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -357,20 +357,21 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
}
/**
- * Set the directory under which RDDs are going to be checkpointed. This method will
- * create this directory and will throw an exception of the path already exists (to avoid
- * overwriting existing files may be overwritten). The directory will be deleted on exit
- * if indicated.
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be a HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists and useExisting is set to true, then the
+ * exisiting directory will be used. Otherwise an exception will be thrown to
+ * prevent accidental overriding of checkpoint files in the existing directory.
*/
def setCheckpointDir(dir: String, useExisting: Boolean) {
sc.setCheckpointDir(dir, useExisting)
}
/**
- * Set the directory under which RDDs are going to be checkpointed. This method will
- * create this directory and will throw an exception of the path already exists (to avoid
- * overwriting existing files may be overwritten). The directory will be deleted on exit
- * if indicated.
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be a HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists, an exception will be thrown to prevent accidental
+ * overriding of checkpoint files.
*/
def setCheckpointDir(dir: String) {
sc.setCheckpointDir(dir)
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index e4c0530241..5526406a20 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -135,8 +135,6 @@ private[spark] class PythonRDD[T: ClassManifest](
}
}
- override def checkpoint() { }
-
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
}
@@ -152,7 +150,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
case Seq(a, b) => (a, b)
case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
}
- override def checkpoint() { }
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 91ac984ba2..45102cd9fe 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -16,4 +16,4 @@ target: docs/
private: no
exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
- pyspark.java_gateway pyspark.examples pyspark.shell
+ pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 1e2f845f9c..dcbed37270 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -123,6 +123,10 @@ class SparkContext(object):
jrdd = self._jsc.textFile(name, minSplits)
return RDD(jrdd, self)
+ def _checkpointFile(self, name):
+ jrdd = self._jsc.checkpointFile(name)
+ return RDD(jrdd, self)
+
def union(self, rdds):
"""
Build the union of a list of RDDs.
@@ -145,7 +149,7 @@ class SparkContext(object):
def accumulator(self, value, accum_param=None):
"""
Create an C{Accumulator} with the given initial value, using a given
- AccumulatorParam helper object to define how to add values of the data
+ AccumulatorParam helper object to define how to add values of the data
type if provided. Default AccumulatorParams are used for integers and
floating-point numbers if you do not provide one. For other types, the
AccumulatorParam must implement two methods:
@@ -195,3 +199,15 @@ class SparkContext(object):
filename = path.split("/")[-1]
os.environ["PYTHONPATH"] = \
"%s:%s" % (filename, os.environ["PYTHONPATH"])
+
+ def setCheckpointDir(self, dirName, useExisting=False):
+ """
+ Set the directory under which RDDs are going to be checkpointed. The
+ directory must be a HDFS path if running on a cluster.
+
+ If the directory does not exist, it will be created. If the directory
+ exists and C{useExisting} is set to true, then the exisiting directory
+ will be used. Otherwise an exception will be thrown to prevent
+ accidental overriding of checkpoint files in the existing directory.
+ """
+ self._jsc.sc().setCheckpointDir(dirName, useExisting)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index b58bf24e3e..d53355a8f1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -32,6 +32,7 @@ class RDD(object):
def __init__(self, jrdd, ctx):
self._jrdd = jrdd
self.is_cached = False
+ self.is_checkpointed = False
self.ctx = ctx
self._partitionFunc = None
@@ -50,6 +51,34 @@ class RDD(object):
self._jrdd.cache()
return self
+ def checkpoint(self):
+ """
+ Mark this RDD for checkpointing. It will be saved to a file inside the
+ checkpoint directory set with L{SparkContext.setCheckpointDir()} and
+ all references to its parent RDDs will be removed. This function must
+ be called before any job has been executed on this RDD. It is strongly
+ recommended that this RDD is persisted in memory, otherwise saving it
+ on a file will require recomputation.
+ """
+ self.is_checkpointed = True
+ self._jrdd.rdd().checkpoint()
+
+ def isCheckpointed(self):
+ """
+ Return whether this RDD has been checkpointed or not
+ """
+ return self._jrdd.rdd().isCheckpointed()
+
+ def getCheckpointFile(self):
+ """
+ Gets the name of the file to which this RDD was checkpointed
+ """
+ checkpointFile = self._jrdd.rdd().getCheckpointFile()
+ if checkpointFile.isDefined():
+ return checkpointFile.get()
+ else:
+ return None
+
# TODO persist(self, storageLevel)
def map(self, f, preservesPartitioning=False):
@@ -667,7 +696,7 @@ class PipelinedRDD(RDD):
20
"""
def __init__(self, prev, func, preservesPartitioning=False):
- if isinstance(prev, PipelinedRDD) and not prev.is_cached:
+ if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
prev_func = prev.func
def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator))
@@ -680,6 +709,7 @@ class PipelinedRDD(RDD):
self.preservesPartitioning = preservesPartitioning
self._prev_jrdd = prev._jrdd
self.is_cached = False
+ self.is_checkpointed = False
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
@@ -712,6 +742,9 @@ class PipelinedRDD(RDD):
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
+ def _is_pipelinable(self):
+ return not (self.is_cached or self.is_checkpointed)
+
def _test():
import doctest
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
new file mode 100644
index 0000000000..b0a403b580
--- /dev/null
+++ b/python/pyspark/tests.py
@@ -0,0 +1,61 @@
+"""
+Unit tests for PySpark; additional tests are implemented as doctests in
+individual modules.
+"""
+import os
+import shutil
+from tempfile import NamedTemporaryFile
+import time
+import unittest
+
+from pyspark.context import SparkContext
+
+
+class TestCheckpoint(unittest.TestCase):
+
+ def setUp(self):
+ self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
+ self.checkpointDir = NamedTemporaryFile(delete=False)
+ os.unlink(self.checkpointDir.name)
+ self.sc.setCheckpointDir(self.checkpointDir.name)
+
+ def tearDown(self):
+ self.sc.stop()
+ # To avoid Akka rebinding to the same port, since it doesn't unbind
+ # immediately on shutdown
+ self.sc.jvm.System.clearProperty("spark.master.port")
+ shutil.rmtree(self.checkpointDir.name)
+
+ def test_basic_checkpointing(self):
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+ flatMappedRDD.checkpoint()
+ result = flatMappedRDD.collect()
+ time.sleep(1) # 1 second
+ self.assertTrue(flatMappedRDD.isCheckpointed())
+ self.assertEqual(flatMappedRDD.collect(), result)
+ self.assertEqual(self.checkpointDir.name,
+ os.path.dirname(flatMappedRDD.getCheckpointFile()))
+
+ def test_checkpoint_and_restore(self):
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: [x])
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+ flatMappedRDD.checkpoint()
+ flatMappedRDD.count() # forces a checkpoint to be computed
+ time.sleep(1) # 1 second
+
+ self.assertIsNotNone(flatMappedRDD.getCheckpointFile())
+ recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+ self.assertEquals([1, 2, 3, 4], recovered.collect())
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/run-tests b/python/run-tests
index 32470911f9..ce214e98a8 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -14,6 +14,9 @@ FAILED=$(($?||$FAILED))
$FWDIR/pyspark -m doctest pyspark/accumulators.py
FAILED=$(($?||$FAILED))
+$FWDIR/pyspark -m unittest pyspark.tests
+FAILED=$(($?||$FAILED))
+
if [[ $FAILED != 0 ]]; then
echo -en "\033[31m" # Red
echo "Had test failures; see logs."