Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py | 49
1 file changed, 33 insertions(+), 16 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index bd4f16e058..ba2347ae76 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -111,6 +111,19 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])
+class Partitioner(object):
+ def __init__(self, numPartitions, partitionFunc):
+ self.numPartitions = numPartitions
+ self.partitionFunc = partitionFunc
+
+ def __eq__(self, other):
+ return (isinstance(other, Partitioner) and self.numPartitions == other.numPartitions
+ and self.partitionFunc == other.partitionFunc)
+
+ def __call__(self, k):
+ return self.partitionFunc(k) % self.numPartitions
+
+
class RDD(object):
"""
@@ -126,7 +139,7 @@ class RDD(object):
self.ctx = ctx
self._jrdd_deserializer = jrdd_deserializer
self._id = jrdd.id()
- self._partitionFunc = None
+ self.partitioner = None
def _pickled(self):
return self._reserialize(AutoBatchedSerializer(PickleSerializer()))
@@ -450,14 +463,17 @@ class RDD(object):
if self._jrdd_deserializer == other._jrdd_deserializer:
rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
self._jrdd_deserializer)
- return rdd
else:
# These RDDs contain data in different serialized formats, so we
# must normalize them to the default serializer.
self_copy = self._reserialize()
other_copy = other._reserialize()
- return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
- self.ctx.serializer)
+ rdd = RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+ self.ctx.serializer)
+ if (self.partitioner == other.partitioner and
+ self.getNumPartitions() == rdd.getNumPartitions()):
+ rdd.partitioner = self.partitioner
+ return rdd
def intersection(self, other):
"""
@@ -1588,6 +1604,9 @@ class RDD(object):
"""
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
+ partitioner = Partitioner(numPartitions, partitionFunc)
+ if self.partitioner == partitioner:
+ return self
# Transferring O(n) objects to Java is too expensive.
# Instead, we'll form the hash buckets in Python,
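(Aside: the early return above makes a redundant repartition a no-op instead of a second shuffle. Illustrative sketch with made-up data, assuming a live SparkContext sc and this patch; the default partitionFunc, portable_hash, compares equal across calls.)

pairs = sc.parallelize([(x, x * x) for x in range(100)])
p1 = pairs.partitionBy(8)   # shuffles into 8 hash partitions
p2 = p1.partitionBy(8)      # equal Partitioner(8, portable_hash): returns p1 itself
print(p2 is p1)             # True with this change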
@@ -1632,18 +1651,16 @@ class RDD(object):
yield pack_long(split)
yield outputSerializer.dumps(items)
- keyed = self.mapPartitionsWithIndex(add_shuffle_key)
+ keyed = self.mapPartitionsWithIndex(add_shuffle_key, preservesPartitioning=True)
keyed._bypass_serializer = True
with SCCallSiteSync(self.context) as css:
pairRDD = self.ctx._jvm.PairwiseRDD(
keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
- id(partitionFunc))
- jrdd = pairRDD.partitionBy(partitioner).values()
+ jpartitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+ id(partitionFunc))
+ jrdd = self.ctx._jvm.PythonRDD.valueOfPair(pairRDD.partitionBy(jpartitioner))
rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
- # This is required so that id(partitionFunc) remains unique,
- # even if partitionFunc is a lambda:
- rdd._partitionFunc = partitionFunc
+ rdd.partitioner = partitioner
return rdd
# TODO: add control over map-side aggregation
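(Aside: with the partitioner now recorded on the Python side, one can check that keys land where the recorded partitioner predicts. Illustrative sketch with made-up keys, assuming a live SparkContext sc and this patch.)

from pyspark.rdd import portable_hash

pairs = sc.parallelize([("apple", 1), ("banana", 2), ("cherry", 3)]).partitionBy(4)
print(pairs.partitioner("apple"))   # predicted partition: portable_hash("apple") % 4
print(pairs.glom().collect())       # each key sits in the partition predicted above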
@@ -1689,7 +1706,7 @@ class RDD(object):
merger.mergeValues(iterator)
return merger.iteritems()
- locally_combined = self.mapPartitions(combineLocally)
+ locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True)
shuffled = locally_combined.partitionBy(numPartitions)
def _mergeCombiners(iterator):
@@ -1698,7 +1715,7 @@ class RDD(object):
merger.mergeCombiners(iterator)
return merger.iteritems()
- return shuffled.mapPartitions(_mergeCombiners, True)
+ return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True)
def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
"""
@@ -2077,8 +2094,8 @@ class RDD(object):
"""
values = self.filter(lambda (k, v): k == key).values()
- if self._partitionFunc is not None:
- return self.ctx.runJob(values, lambda x: x, [self._partitionFunc(key)], False)
+ if self.partitioner is not None:
+ return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False)
return values.collect()
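(Aside: the practical effect is that lookup on a hash-partitioned RDD runs a job on only the one partition the key hashes to, rather than filtering every partition. Illustrative sketch with made-up data, assuming a live SparkContext sc and this patch.)

pairs = sc.parallelize([(i, str(i)) for i in range(1000)]).partitionBy(16)
print(pairs.lookup(42))   # scans only partition pairs.partitioner(42); prints ['42']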
@@ -2243,7 +2260,7 @@ class PipelinedRDD(RDD):
self._id = None
self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
- self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
+ self.partitioner = prev.partitioner if self.preservesPartitioning else None
self._broadcast = None
def __del__(self):
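(Aside: with PipelinedRDD carrying the partitioner forward, transformations that declare preservesPartitioning keep the layout, while a plain map drops it because it may change keys. Illustrative sketch, assuming a live SparkContext sc and this patch.)

pairs = sc.parallelize([("a", 1), ("b", 2)]).partitionBy(4)
kept = pairs.mapValues(lambda v: v * 10)          # mapValues preserves partitioning
lost = pairs.map(lambda kv: (kv[0], kv[1] * 10))  # plain map: partitioner is dropped
print(kept.partitioner == pairs.partitioner)      # True
print(lost.partitioner)                           # None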