Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py  77
1 file changed, 58 insertions(+), 19 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d705f0f9e1..4cda6cf661 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1,4 +1,3 @@
-import atexit
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
@@ -32,7 +31,9 @@ class RDD(object):
def __init__(self, jrdd, ctx):
self._jrdd = jrdd
self.is_cached = False
+ self.is_checkpointed = False
self.ctx = ctx
+ self._partitionFunc = None
@property
def context(self):
@@ -49,6 +50,34 @@ class RDD(object):
self._jrdd.cache()
return self
+ def checkpoint(self):
+ """
+ Mark this RDD for checkpointing. It will be saved to a file inside the
+ checkpoint directory set with L{SparkContext.setCheckpointDir()} and
+ all references to its parent RDDs will be removed. This function must
+ be called before any job has been executed on this RDD. It is strongly
+ recommended that this RDD be persisted in memory; otherwise, saving it
+ to a file will require recomputation.
+ """
+ self.is_checkpointed = True
+ self._jrdd.rdd().checkpoint()
+
+ def isCheckpointed(self):
+ """
+ Return whether this RDD has been checkpointed or not
+ """
+ return self._jrdd.rdd().isCheckpointed()
+
+ def getCheckpointFile(self):
+ """
+ Gets the name of the file to which this RDD was checkpointed
+ """
+ checkpointFile = self._jrdd.rdd().getCheckpointFile()
+ if checkpointFile.isDefined():
+ return checkpointFile.get()
+ else:
+ return None
+
# TODO persist(self, storageLevel)
def map(self, f, preservesPartitioning=False):
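The checkpointing methods added above mirror the JVM-side RDD API. A minimal usage sketch, assuming a local SparkContext; the checkpoint directory path and app name below are illustrative, not part of this change:

    from pyspark import SparkContext

    sc = SparkContext('local', 'CheckpointExample')
    sc.setCheckpointDir('/tmp/spark-checkpoints')   # illustrative path

    rdd = sc.parallelize(range(100)).map(lambda x: x * 2)
    rdd.cache()        # recommended, so checkpointing does not recompute the RDD
    rdd.checkpoint()   # must be called before any job has run on this RDD
    rdd.count()        # the first action materializes and writes the checkpoint

    print(rdd.isCheckpointed())      # True once the checkpoint has been written
    print(rdd.getCheckpointFile())   # checkpoint file path, or None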
@@ -234,12 +263,8 @@ class RDD(object):
# Transferring lots of data through Py4J can be slow because
# socket.readline() is inefficient. Instead, we'll dump the data to a
# file and read it back.
- tempFile = NamedTemporaryFile(delete=False)
+ tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
- def clean_up_file():
- try: os.unlink(tempFile.name)
- except: pass
- atexit.register(clean_up_file)
self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
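For context, this hunk switches from one atexit handler per temp file to writing all scratch files into a SparkContext-managed temporary directory that is removed once at exit. A standalone sketch of that pattern; the _temp_dir wiring itself lives in SparkContext, so the names here are stand-ins:

    import atexit
    import shutil
    import tempfile
    from tempfile import NamedTemporaryFile

    # One managed scratch directory, cleaned up once at interpreter exit:
    temp_dir = tempfile.mkdtemp(prefix='spark-example-')
    atexit.register(shutil.rmtree, temp_dir, True)   # True -> ignore_errors

    # Files created inside it need no per-file cleanup handler:
    tempFile = NamedTemporaryFile(delete=False, dir=temp_dir)
    tempFile.write(b'intermediate data')
    tempFile.close()
    with open(tempFile.name, 'rb') as f:
        data = f.read()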
@@ -347,6 +372,10 @@ class RDD(object):
items = []
for partition in range(self._jrdd.splits().size()):
iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
+ # Each item in the iterator is a string, a Python object, or a batch of
+ # Python objects. Either way, taking `num` of these items is enough to
+ # collect at least `num` Python objects:
+ iterator = iterator.take(num)
items.extend(self._collect_iterator_through_file(iterator))
if len(items) >= num:
break
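With the iterator.take(num) call above, take() pulls at most `num` serialized items from each partition it scans instead of collecting whole partitions. Usage is unchanged; for example, with a SparkContext `sc` as in the module's doctests:

    sc.parallelize(range(1000), 10).take(5)   # -> [0, 1, 2, 3, 4]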
@@ -377,7 +406,7 @@ class RDD(object):
return (str(x).encode("utf-8") for x in iterator)
keyed = PipelinedRDD(self, func)
keyed._bypass_serializer = True
- keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
+ keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)
# Pair functions
@@ -497,7 +526,7 @@ class RDD(object):
return python_right_outer_join(self, other, numSplits)
# TODO: add option to control map-side combining
- def partitionBy(self, numSplits, hashFunc=hash):
+ def partitionBy(self, numSplits, partitionFunc=hash):
"""
Return a copy of the RDD partitioned using the specified partitioner.
@@ -514,17 +543,21 @@ class RDD(object):
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
for (k, v) in iterator:
- buckets[hashFunc(k) % numSplits].append((k, v))
+ buckets[partitionFunc(k) % numSplits].append((k, v))
for (split, items) in buckets.iteritems():
yield str(split)
yield dump_pickle(Batch(items))
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
- pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
- jrdd = pairRDD.partitionBy(partitioner)
- jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
- return RDD(jrdd, self.ctx)
+ pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+ partitioner = self.ctx._jvm.PythonPartitioner(numSplits,
+ id(partitionFunc))
+ jrdd = pairRDD.partitionBy(partitioner).values()
+ rdd = RDD(jrdd, self.ctx)
+ # Keep a reference to partitionFunc so it is not garbage collected and
+ # id(partitionFunc) remains unique, even if partitionFunc is a lambda:
+ rdd._partitionFunc = partitionFunc
+ return rdd
# TODO: add control over map-side aggregation
def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
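One note on the id(partitionFunc) scheme above: the JVM-side PythonPartitioner can only treat two Python RDDs as partitioned the same way when both were built with the very same function object. A usage sketch, assuming a SparkContext `sc`; the helper name is made up:

    # Reuse one function object so id(first_char) is identical on both RDDs:
    def first_char(k):
        return ord(k[0])

    a = sc.parallelize([('apple', 1), ('banana', 2)]).partitionBy(4, first_char)
    b = sc.parallelize([('avocado', 3), ('berry', 4)]).partitionBy(4, first_char)
    # Both JVM partitioners now carry the same (numSplits, id(first_char)) pair,
    # so they compare equal; a lambda defined inline at each call site would not.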
@@ -662,7 +695,7 @@ class PipelinedRDD(RDD):
20
"""
def __init__(self, prev, func, preservesPartitioning=False):
- if isinstance(prev, PipelinedRDD) and not prev.is_cached:
+ if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
prev_func = prev.func
def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator))
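The pipeline_func closure above is plain function composition: the previous stage's per-partition function runs first and the new stage consumes its output, so a chain of narrow transformations is evaluated in one pass over each partition. A toy illustration outside of Spark:

    prev_func = lambda split, iterator: (x * 2 for x in iterator)    # earlier stage
    func = lambda split, iterator: (x for x in iterator if x > 4)    # new stage

    def pipeline_func(split, iterator):
        return func(split, prev_func(split, iterator))

    list(pipeline_func(0, iter([1, 2, 3, 4])))   # -> [6, 8]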
@@ -675,6 +708,7 @@ class PipelinedRDD(RDD):
self.preservesPartitioning = preservesPartitioning
self._prev_jrdd = prev._jrdd
self.is_cached = False
+ self.is_checkpointed = False
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
@@ -695,18 +729,21 @@ class PipelinedRDD(RDD):
pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
- self.ctx.gateway._gateway_client)
+ self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
class_manifest = self._prev_jrdd.classManifest()
env = copy.copy(self.ctx.environment)
env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
- env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
- python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
+ env = MapConverter().convert(env, self.ctx._gateway._gateway_client)
+ python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
broadcast_vars, self.ctx._javaAccumulator, class_manifest)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
+ def _is_pipelinable(self):
+ return not (self.is_cached or self.is_checkpointed)
+
def _test():
import doctest
@@ -715,8 +752,10 @@ def _test():
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
- doctest.testmod(globs=globs)
+ (failure_count, test_count) = doctest.testmod(globs=globs)
globs['sc'].stop()
+ if failure_count:
+ exit(-1)
if __name__ == "__main__":
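The change to _test() above makes doctest failures visible to the caller: testmod() returns a (failure_count, test_count) pair, and exiting non-zero lets a CI job fail the build. A minimal standalone version of the same pattern (sys.exit is used here in place of the bare exit builtin):

    import doctest
    import sys

    if __name__ == "__main__":
        (failure_count, test_count) = doctest.testmod()
        if failure_count:
            sys.exit(-1)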