diff options
author | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2013-01-20 13:59:45 -0800 |
---|---|---|
committer | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2013-01-20 13:59:45 -0800 |
commit | d0ba80dc727d00b2b7627dcefd2c77009af55f7d (patch) | |
tree | eadf5f88feb468179527459a1f7316b3ca5ac2cf /python | |
parent | 7ed1bf4b485131d58ea6728e7247b79320aca9e6 (diff) | |
download | spark-d0ba80dc727d00b2b7627dcefd2c77009af55f7d.tar.gz spark-d0ba80dc727d00b2b7627dcefd2c77009af55f7d.tar.bz2 spark-d0ba80dc727d00b2b7627dcefd2c77009af55f7d.zip |
Add checkpointFile() and more tests to PySpark.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/context.py | 6 | ||||
-rw-r--r-- | python/pyspark/rdd.py | 9 | ||||
-rw-r--r-- | python/pyspark/tests.py | 24 |
3 files changed, 37 insertions, 2 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a438b43fdc..8beb8e2ae9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -123,6 +123,10 @@ class SparkContext(object): jrdd = self._jsc.textFile(name, minSplits) return RDD(jrdd, self) + def _checkpointFile(self, name): + jrdd = self._jsc.checkpointFile(name) + return RDD(jrdd, self) + def union(self, rdds): """ Build the union of a list of RDDs. @@ -145,7 +149,7 @@ class SparkContext(object): def accumulator(self, value, accum_param=None): """ Create an C{Accumulator} with the given initial value, using a given - AccumulatorParam helper object to define how to add values of the data + AccumulatorParam helper object to define how to add values of the data type if provided. Default AccumulatorParams are used for integers and floating-point numbers if you do not provide one. For other types, the AccumulatorParam must implement two methods: diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9b676cae4a..2a2ff9b271 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -32,6 +32,7 @@ class RDD(object): def __init__(self, jrdd, ctx): self._jrdd = jrdd self.is_cached = False + self.is_checkpointed = False self.ctx = ctx @property @@ -65,6 +66,7 @@ class RDD(object): (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will require recomputation. """ + self.is_checkpointed = True self._jrdd.rdd().checkpoint() def isCheckpointed(self): @@ -696,7 +698,7 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and not prev.is_cached: + if isinstance(prev, PipelinedRDD) and prev._is_pipelinable: prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) @@ -709,6 +711,7 @@ class PipelinedRDD(RDD): self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd self.is_cached = False + self.is_checkpointed = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None @@ -741,6 +744,10 @@ class PipelinedRDD(RDD): self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val + @property + def _is_pipelinable(self): + return not (self.is_cached or self.is_checkpointed) + def _test(): import doctest diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c959d5dec7..83283fca4f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -19,6 +19,9 @@ class TestCheckpoint(unittest.TestCase): def tearDown(self): self.sc.stop() + # To avoid Akka rebinding to the same port, since it doesn't unbind + # immediately on shutdown + self.sc.jvm.System.clearProperty("spark.master.port") def test_basic_checkpointing(self): checkpointDir = NamedTemporaryFile(delete=False) @@ -41,6 +44,27 @@ class TestCheckpoint(unittest.TestCase): atexit.register(lambda: shutil.rmtree(checkpointDir.name)) + def test_checkpoint_and_restore(self): + checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(checkpointDir.name) + self.sc.setCheckpointDir(checkpointDir.name) + + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: [x]) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertIsNone(flatMappedRDD.getCheckpointFile()) + + flatMappedRDD.checkpoint() + flatMappedRDD.count() # forces a checkpoint to be computed + time.sleep(1) # 1 second + + self.assertIsNotNone(flatMappedRDD.getCheckpointFile()) + recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) + self.assertEquals([1, 2, 3, 4], recovered.collect()) + + atexit.register(lambda: shutil.rmtree(checkpointDir.name)) + if __name__ == "__main__": unittest.main() |