about summary refs log tree commit diff
path: root/python/pyspark/rdd.py
diff options
context:
space:
mode:
author    Josh Rosen <joshrosen@eecs.berkeley.edu>  2013-01-14 15:30:42 -0800
committer Josh Rosen <joshrosen@eecs.berkeley.edu>  2013-01-20 15:41:42 -0800
commit    9f211dd3f0132daf72fb39883fa4b28e4fd547ca (patch)
tree      270d3bf88a053e858921277d329b5ace6843bac1 /python/pyspark/rdd.py
parent    fe85a075117a79675971aff0cd020bba446c0233 (diff)
download  spark-9f211dd3f0132daf72fb39883fa4b28e4fd547ca.tar.gz
          spark-9f211dd3f0132daf72fb39883fa4b28e4fd547ca.tar.bz2
          spark-9f211dd3f0132daf72fb39883fa4b28e4fd547ca.zip
Fix PythonPartitioner equality; see SPARK-654.
PythonPartitioner did not take the Python-side partitioning function into account when checking for equality, which might cause problems in the future.
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py | 17 +++++++++++------
1 file changed, 11 insertions(+), 6 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d705f0f9e1..b58bf24e3e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -33,6 +33,7 @@ class RDD(object):
self._jrdd = jrdd
self.is_cached = False
self.ctx = ctx
+ self._partitionFunc = None
@property
def context(self):
@@ -497,7 +498,7 @@ class RDD(object):
return python_right_outer_join(self, other, numSplits)
# TODO: add option to control map-side combining
- def partitionBy(self, numSplits, hashFunc=hash):
+ def partitionBy(self, numSplits, partitionFunc=hash):
"""
Return a copy of the RDD partitioned using the specified partitioner.
@@ -514,17 +515,21 @@ class RDD(object):
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
for (k, v) in iterator:
- buckets[hashFunc(k) % numSplits].append((k, v))
+ buckets[partitionFunc(k) % numSplits].append((k, v))
for (split, items) in buckets.iteritems():
yield str(split)
yield dump_pickle(Batch(items))
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
- jrdd = pairRDD.partitionBy(partitioner)
- jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
- return RDD(jrdd, self.ctx)
+ partitioner = self.ctx.jvm.PythonPartitioner(numSplits,
+ id(partitionFunc))
+ jrdd = pairRDD.partitionBy(partitioner).values()
+ rdd = RDD(jrdd, self.ctx)
+ # This is required so that id(partitionFunc) remains unique, even if
+ # partitionFunc is a lambda:
+ rdd._partitionFunc = partitionFunc
+ return rdd
# TODO: add control over map-side aggregation
def combineByKey(self, createCombiner, mergeValue, mergeCombiners,