author     Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-10-28 22:30:28 -0700
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-10-28 22:30:28 -0700
commit     2ccf3b665280bf5b0919e3801d028126cb070dbd (patch)
tree       c486342f7d31dd1f529444ca40339e2a879ac219 /pyspark
parent     7859879aaa1860ff6b383e32a18fd9a410a97416 (diff)
Fix PySpark hash partitioning bug.
A Java array's hashCode is based on its object identity, not its elements, so this was causing serialized keys to be hashed incorrectly. This commit adds a PySpark-specific workaround and adds more tests.
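For context, the JVM behavior behind the bug can be shown in a few lines of Scala. This is a minimal standalone sketch, not code from this commit (the object name and values are illustrative): two byte arrays with identical contents report different identity-based hash codes, while content-based hashing such as java.util.Arrays.hashCode treats them as equal, which is the property a partitioner needs when keys are serialized byte arrays.

object ArrayHashCodeSketch {
  def main(args: Array[String]): Unit = {
    // Two distinct array objects holding the same serialized-key bytes.
    val a = Array[Byte](1, 2, 3)
    val b = Array[Byte](1, 2, 3)

    // Java arrays inherit Object.hashCode, which is identity-based, so
    // equal contents almost never produce equal hash codes.
    println(a.hashCode == b.hashCode)  // usually false

    // Content-based hashing agrees for equal contents, so equal keys
    // would be routed to the same partition.
    println(java.util.Arrays.hashCode(a) == java.util.Arrays.hashCode(b))  // true
  }
}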
Diffstat (limited to 'pyspark')
-rw-r--r--  pyspark/pyspark/rdd.py  12
1 file changed, 9 insertions, 3 deletions
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index e4878c08ba..85a24c6854 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -310,6 +310,12 @@ class RDD(object):
         return python_right_outer_join(self, other, numSplits)
 
     def partitionBy(self, numSplits, hashFunc=hash):
+        """
+        >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
+        >>> sets = pairs.partitionBy(2).glom().collect()
+        >>> set(sets[0]).intersection(set(sets[1]))
+        set([])
+        """
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
         def add_shuffle_key(iterator):
@@ -319,7 +325,7 @@ class RDD(object):
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
+        partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
         jrdd = pairRDD.partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
         return RDD(jrdd, self.ctx)
@@ -391,7 +397,7 @@ class RDD(object):
"""
>>> x = sc.parallelize([("a", 1), ("b", 4)])
>>> y = sc.parallelize([("a", 2)])
- >>> x.cogroup(y).collect()
+ >>> sorted(x.cogroup(y).collect())
[('a', ([1], [2])), ('b', ([4], []))]
"""
return python_cogroup(self, other, numSplits)
@@ -462,7 +468,7 @@ def _test():
     import doctest
     from pyspark.context import SparkContext
     globs = globals().copy()
-    globs['sc'] = SparkContext('local', 'PythonTest')
+    globs['sc'] = SparkContext('local[4]', 'PythonTest')
     doctest.testmod(globs=globs)
     globs['sc'].stop()