Diffstat (limited to 'python/pyspark/rdd.py')
 python/pyspark/rdd.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 3a2e7649e6..31919741e9 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -44,7 +44,7 @@ from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
- get_used_memory
+ get_used_memory, ExternalSorter
from py4j.java_collections import ListConverter, MapConverter
@@ -605,8 +605,13 @@ class RDD(object):
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
+ spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
+ memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+ serializer = self._jrdd_deserializer
+
def sortPartition(iterator):
- return iter(sorted(iterator, key=lambda (k, v): keyfunc(k), reverse=not ascending))
+ sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
+ return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending)))
if numPartitions == 1:
if self.getNumPartitions() > 1:
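
For reference, the new sortPartition above keeps the plain in-memory sorted() when spark.shuffle.spill is disabled, and otherwise hands the partition to ExternalSorter with 90% of the spark.python.worker.memory budget, so a partition that does not fit in memory is sorted via spill files on disk. The standalone Python sketch below illustrates that spill-and-merge pattern only; the name external_sorted, the batch size, and the pickle-based spill format are illustrative assumptions, not the actual pyspark.shuffle.ExternalSorter implementation.

import heapq
import itertools
import pickle
import tempfile


def external_sorted(iterator, key=None, reverse=False, batch=10000):
    """Sort a possibly larger-than-memory iterator by sorting fixed-size
    chunks, spilling each sorted chunk to a temporary file, and lazily
    merging the spill files. A rough stand-in for ExternalSorter.sorted."""
    iterator = iter(iterator)
    spills = []
    while True:
        chunk = list(itertools.islice(iterator, batch))
        if not chunk:
            break
        chunk.sort(key=key, reverse=reverse)
        f = tempfile.TemporaryFile()
        for item in chunk:
            pickle.dump(item, f)
        f.seek(0)
        spills.append(f)

    def load(f):
        # Stream records back from one spill file until EOF.
        while True:
            try:
                yield pickle.load(f)
            except EOFError:
                return

    # heapq.merge combines the pre-sorted runs without loading them all at once.
    return heapq.merge(*[load(f) for f in spills], key=key, reverse=reverse)


# Example: sort key/value pairs by key, mirroring what sortPartition does
# per partition when spilling is enabled.
pairs = [(k, None) for k in range(50000, 0, -1)]
out = list(external_sorted(iter(pairs), key=lambda kv: kv[0]))
assert out[0][0] == 1 and out[-1][0] == 50000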