-rw-r--r--  python/pyspark/context.py   6
-rw-r--r--  python/pyspark/rdd.py       9
-rw-r--r--  python/pyspark/tests.py    24
3 files changed, 37 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a438b43fdc..8beb8e2ae9 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -123,6 +123,10 @@ class SparkContext(object):
         jrdd = self._jsc.textFile(name, minSplits)
         return RDD(jrdd, self)
 
+    def _checkpointFile(self, name):
+        jrdd = self._jsc.checkpointFile(name)
+        return RDD(jrdd, self)
+
     def union(self, rdds):
         """
         Build the union of a list of RDDs.
@@ -145,7 +149,7 @@ class SparkContext(object):
     def accumulator(self, value, accum_param=None):
         """
         Create an C{Accumulator} with the given initial value, using a given
-        AccumulatorParam helper object to define how to add values of the data 
+        AccumulatorParam helper object to define how to add values of the data
         type if provided. Default AccumulatorParams are used for integers and
         floating-point numbers if you do not provide one. For other types, the
         AccumulatorParam must implement two methods:
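
The new private _checkpointFile() helper is the read side of RDD.checkpoint(). A minimal sketch of the intended round trip, assuming a local SparkContext; the checkpoint directory and data below are illustrative, not part of the patch:

    from pyspark.context import SparkContext

    sc = SparkContext("local", "checkpoint-demo")
    sc.setCheckpointDir("/tmp/demo-checkpoints")  # illustrative path

    rdd = sc.parallelize([1, 2, 3, 4]).map(lambda x: 2 * x)
    rdd.checkpoint()
    rdd.count()  # an action forces the checkpoint to be written out

    # Later (e.g. after driver recovery), reload the materialized data:
    restored = sc._checkpointFile(rdd.getCheckpointFile())
    assert restored.collect() == [2, 4, 6, 8]
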
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 9b676cae4a..2a2ff9b271 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -32,6 +32,7 @@ class RDD(object):
     def __init__(self, jrdd, ctx):
         self._jrdd = jrdd
         self.is_cached = False
+        self.is_checkpointed = False
         self.ctx = ctx
 
     @property
@@ -65,6 +66,7 @@ class RDD(object):
         (ii) This RDD has been made to persist in memory. Otherwise saving it
         on a file will require recomputation.
         """
+        self.is_checkpointed = True
         self._jrdd.rdd().checkpoint()
 
     def isCheckpointed(self):
@@ -696,7 +698,7 @@ class PipelinedRDD(RDD):
     20
     """
     def __init__(self, prev, func, preservesPartitioning=False):
-        if isinstance(prev, PipelinedRDD) and not prev.is_cached:
+        if isinstance(prev, PipelinedRDD) and prev._is_pipelinable:
             prev_func = prev.func
             def pipeline_func(split, iterator):
                 return func(split, prev_func(split, iterator))
@@ -709,6 +711,7 @@ class PipelinedRDD(RDD):
         self.preservesPartitioning = preservesPartitioning
         self._prev_jrdd = prev._jrdd
         self.is_cached = False
+        self.is_checkpointed = False
         self.ctx = prev.ctx
         self.prev = prev
         self._jrdd_val = None
@@ -741,6 +744,10 @@ class PipelinedRDD(RDD):
         self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
 
+    @property
+    def _is_pipelinable(self):
+        return not (self.is_cached or self.is_checkpointed)
+
 
 def _test():
     import doctest
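
The _is_pipelinable guard exists because PipelinedRDD fuses chained functions into a single pass over each partition; fusing across a cached or checkpointed parent would recompute the parent instead of reusing its materialized data. A self-contained sketch of the fusion idea, mirroring pipeline_func above (the names compose, f, and g are illustrative, not PySpark API):

    def compose(prev_func, func):
        # Feed prev_func's output straight into func, one partition at a time.
        def pipeline_func(split, iterator):
            return func(split, prev_func(split, iterator))
        return pipeline_func

    f = lambda split, it: (x + 1 for x in it)   # first map stage
    g = lambda split, it: (x * 10 for x in it)  # second map stage
    fused = compose(f, g)
    assert list(fused(0, iter([1, 2, 3]))) == [20, 30, 40]
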
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index c959d5dec7..83283fca4f 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -19,6 +19,9 @@ class TestCheckpoint(unittest.TestCase):
     def tearDown(self):
         self.sc.stop()
+        # To avoid Akka rebinding to the same port, since it doesn't unbind
+        # immediately on shutdown
+        self.sc.jvm.System.clearProperty("spark.master.port")
 
     def test_basic_checkpointing(self):
         checkpointDir = NamedTemporaryFile(delete=False)
@@ -41,6 +44,27 @@ class TestCheckpoint(unittest.TestCase):
         atexit.register(lambda: shutil.rmtree(checkpointDir.name))
 
+    def test_checkpoint_and_restore(self):
+        checkpointDir = NamedTemporaryFile(delete=False)
+        os.unlink(checkpointDir.name)
+        self.sc.setCheckpointDir(checkpointDir.name)
+
+        parCollection = self.sc.parallelize([1, 2, 3, 4])
+        flatMappedRDD = parCollection.flatMap(lambda x: [x])
+
+        self.assertFalse(flatMappedRDD.isCheckpointed())
+        self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+        flatMappedRDD.checkpoint()
+        flatMappedRDD.count()  # forces a checkpoint to be computed
+        time.sleep(1)  # 1 second
+
+        self.assertIsNotNone(flatMappedRDD.getCheckpointFile())
+        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+        self.assertEquals([1, 2, 3, 4], recovered.collect())
+
+        atexit.register(lambda: shutil.rmtree(checkpointDir.name))
+
 
 if __name__ == "__main__":
     unittest.main()
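
A quick interactive check of the new bookkeeping, as a sketch against a local SparkContext (the sc variable is assumed to already exist):

    rdd = sc.parallelize(range(10)).map(str)  # a PipelinedRDD
    assert rdd._is_pipelinable                # nothing pins it yet
    rdd.checkpoint()
    assert rdd.is_checkpointed                # set eagerly by checkpoint()
    assert not rdd._is_pipelinable            # later stages start a new pipeline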