diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/rdd.py | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a0b2c744f0..62a95c8467 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -512,7 +512,7 @@ class RDD(object): [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)] """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._defaultReducePartitions() bounds = list() @@ -1154,7 +1154,7 @@ class RDD(object): set([]) """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._defaultReducePartitions() if partitionFunc is None: partitionFunc = lambda x: 0 if x is None else hash(x) @@ -1212,7 +1212,7 @@ class RDD(object): [('a', '11'), ('b', '1')] """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._defaultReducePartitions() def combineLocally(iterator): combiners = {} for x in iterator: @@ -1475,6 +1475,21 @@ class RDD(object): java_storage_level.replication()) return storage_level + def _defaultReducePartitions(self): + """ + Returns the default number of partitions to use during reduce tasks (e.g., groupBy). + If spark.default.parallelism is set, then we'll use the value from SparkContext + defaultParallelism, otherwise we'll use the number of partitions in this RDD. + + This mirrors the behavior of the Scala Partitioner#defaultPartitioner, intended to reduce + the likelihood of OOMs. Once PySpark adopts Partitioner-based APIs, this behavior will + be inherent. + """ + if self.ctx._conf.contains("spark.default.parallelism"): + return self.ctx.defaultParallelism + else: + return self.getNumPartitions() + # TODO: `lookup` is disabled because we can't make direct comparisons based # on the key; we need to compare the hash of the key to the hash of the # keys in the pairs. This could be an expensive operation, since those |