author     Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-10-28 22:30:28 -0700
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-10-28 22:30:28 -0700
commit     2ccf3b665280bf5b0919e3801d028126cb070dbd
tree       c486342f7d31dd1f529444ca40339e2a879ac219 /pyspark
parent     7859879aaa1860ff6b383e32a18fd9a410a97416
Fix PySpark hash partitioning bug.

A Java array's hashCode is based on its object identity, not its
elements, so serialized keys were being hashed incorrectly.

This commit adds a PySpark-specific workaround and adds more tests.
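As an aside (not part of the commit), the JVM behavior behind the bug is easy to reproduce in a Scala REPL; the byte array below is just an example standing in for a serialized key:

    // Two byte arrays with identical contents, as two serializations of
    // the same key would be.
    val a = Array[Byte](1, 2, 3)
    val b = Array[Byte](1, 2, 3)
    a.sameElements(b)                 // true: same contents
    a.hashCode == b.hashCode          // false: Array inherits Object's
                                      //        identity-based hashCode
    // Hashing by content instead treats them as equal:
    java.util.Arrays.hashCode(a) == java.util.Arrays.hashCode(b)  // true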
Diffstat (limited to 'pyspark')
 pyspark/pyspark/rdd.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index e4878c08ba..85a24c6854 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -310,6 +310,12 @@ class RDD(object):
         return python_right_outer_join(self, other, numSplits)
 
     def partitionBy(self, numSplits, hashFunc=hash):
+        """
+        >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
+        >>> sets = pairs.partitionBy(2).glom().collect()
+        >>> set(sets[0]).intersection(set(sets[1]))
+        set([])
+        """
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
         def add_shuffle_key(iterator):
@@ -319,7 +325,7 @@ class RDD(object):
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
+        partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
         jrdd = pairRDD.partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
         return RDD(jrdd, self.ctx)
@@ -391,7 +397,7 @@ class RDD(object):
         """
         >>> x = sc.parallelize([("a", 1), ("b", 4)])
         >>> y = sc.parallelize([("a", 2)])
-        >>> x.cogroup(y).collect()
+        >>> sorted(x.cogroup(y).collect())
         [('a', ([1], [2])), ('b', ([4], []))]
         """
         return python_cogroup(self, other, numSplits)
@@ -462,7 +468,7 @@ def _test():
     import doctest
     from pyspark.context import SparkContext
     globs = globals().copy()
-    globs['sc'] = SparkContext('local', 'PythonTest')
+    globs['sc'] = SparkContext('local[4]', 'PythonTest')
     doctest.testmod(globs=globs)
     globs['sc'].stop()
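For context, the spark.api.python.PythonPartitioner referenced in the diff lives on the Scala side of the tree. A minimal sketch of the idea, assuming the era's spark.Partitioner interface; the class name and details here are illustrative, not the commit's actual source:

    import java.util.Arrays

    // Illustrative sketch: partition serialized (byte-array) keys by the
    // hash of their contents rather than by object identity.
    class ContentHashPartitioner(override val numPartitions: Int)
        extends spark.Partitioner {

      override def getPartition(key: Any): Int = {
        val hash = key match {
          case bytes: Array[Byte] => Arrays.hashCode(bytes) // content-based
          case other => other.hashCode() // non-array keys already hash by value
        }
        val mod = hash % numPartitions
        if (mod < 0) mod + numPartitions else mod // keep the index non-negative
      }

      override def equals(other: Any): Boolean = other match {
        case p: ContentHashPartitioner => p.numPartitions == numPartitions
        case _ => false
      }
    }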