author     Davies Liu <davies.liu@gmail.com>      2014-08-26 16:57:40 -0700
committer  Matei Zaharia <matei@databricks.com>   2014-08-26 16:57:40 -0700
commit     f1e71d4c3ba678fc108effb05cf2d6101dadc0ce (patch)
tree       ef5c761a9bf3a75c59b03148985a4a83e64a2c16 /python/pyspark/shuffle.py
parent     c4787a3690a9ed3b8b2c6c294fc4a6915436b6f7 (diff)
[SPARK-3073] [PySpark] use external sort in sortBy() and sortByKey()
Use an external sort to support sorting large datasets in the reduce stage.

Author: Davies Liu <davies.liu@gmail.com>

Closes #1978 from davies/sort and squashes the following commits:

bbcd9ba [Davies Liu] check spilled bytes in tests
b125d2f [Davies Liu] add test for external sort in rdd
eae0176 [Davies Liu] choose different disks from different processes and instances
1f075ed [Davies Liu] Merge branch 'master' into sort
eb53ca6 [Davies Liu] Merge branch 'master' into sort
644abaf [Davies Liu] add license in LICENSE
19f7873 [Davies Liu] improve tests
55602ee [Davies Liu] use external sort in sortBy() and sortByKey()
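The squashed commits above mention a test that checks spilled bytes; below is a minimal sketch of how the new ExternalSorter (added in the diff that follows) can be exercised on its own. It assumes the same Python 2 setup as the doctest in the diff; the 1 MB limit and the data size are illustrative, not taken from the actual test.

    import random
    from pyspark.shuffle import ExternalSorter

    sorter = ExternalSorter(1)              # 1 MB limit, so almost any input forces spilling
    data = range(1 << 16)                   # Python 2: range() returns a list
    random.shuffle(data)

    # key=-x with reverse=True is equivalent to a plain ascending sort
    result = list(sorter.sorted(data, key=lambda x: -x, reverse=True))
    assert result == sorted(data)
    assert sorter._spilled_bytes > 0        # chunks were written to and merged back from disk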
Diffstat (limited to 'python/pyspark/shuffle.py')
-rw-r--r--  python/pyspark/shuffle.py | 91
1 file changed, 83 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 1ebe7df418..49829f5280 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -21,7 +21,10 @@ import platform
import shutil
import warnings
import gc
+import itertools
+import random
+import pyspark.heapq3 as heapq
from pyspark.serializers import BatchedSerializer, PickleSerializer
try:
@@ -54,6 +57,17 @@ except ImportError:
        return 0

+def _get_local_dirs(sub):
+    """ Get all the directories """
+    path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp")
+    dirs = path.split(",")
+    if len(dirs) > 1:
+        # different order in different processes and instances
+        rnd = random.Random(os.getpid() + id(dirs))
+        random.shuffle(dirs, rnd.random)
+    return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]
+
+
class Aggregator(object):

    """
@@ -196,7 +210,7 @@ class ExternalMerger(Merger):
        # default serializer is only used for tests
        self.serializer = serializer or \
            BatchedSerializer(PickleSerializer(), 1024)
-        self.localdirs = localdirs or self._get_dirs()
+        self.localdirs = localdirs or _get_local_dirs(str(id(self)))
        # number of partitions when spill data into disks
        self.partitions = partitions
        # check the memory after # of items merged
@@ -212,13 +226,6 @@ class ExternalMerger(Merger):
        # randomize the hash of key, id(o) is the address of o (aligned by 8)
        self._seed = id(self) + 7

-    def _get_dirs(self):
-        """ Get all the directories """
-        path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp")
-        dirs = path.split(",")
-        return [os.path.join(d, "python", str(os.getpid()), str(id(self)))
-                for d in dirs]
-
    def _get_spill_dir(self, n):
        """ Choose one directory for spill by number n """
        return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))
@@ -434,6 +441,74 @@ class ExternalMerger(Merger):
                os.remove(os.path.join(path, str(i)))
+class ExternalSorter(object):
+    """
+    ExternalSorter will divide the elements into chunks, sort them in
+    memory, dump them onto disk, and finally merge them back.
+
+    The spilling will only happen when the used memory goes above
+    the limit.
+
+    >>> sorter = ExternalSorter(1)  # 1M
+    >>> import random
+    >>> l = range(1024)
+    >>> random.shuffle(l)
+    >>> sorted(l) == list(sorter.sorted(l))
+    True
+    >>> sorted(l) == list(sorter.sorted(l, key=lambda x: -x, reverse=True))
+    True
+    """
+    def __init__(self, memory_limit, serializer=None):
+        self.memory_limit = memory_limit
+        self.local_dirs = _get_local_dirs("sort")
+        self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
+        self._spilled_bytes = 0
+
+    def _get_path(self, n):
+        """ Choose one directory for spill by number n """
+        d = self.local_dirs[n % len(self.local_dirs)]
+        if not os.path.exists(d):
+            os.makedirs(d)
+        return os.path.join(d, str(n))
+
+    def sorted(self, iterator, key=None, reverse=False):
+        """
+        Sort the elements in iterator, doing an external sort when the
+        memory used goes above the limit.
+        """
+        batch = 10
+        chunks, current_chunk = [], []
+        iterator = iter(iterator)
+        while True:
+            # pick elements in batch
+            chunk = list(itertools.islice(iterator, batch))
+            current_chunk.extend(chunk)
+            if len(chunk) < batch:
+                break
+
+            if get_used_memory() > self.memory_limit:
+                # sorting them in place saves memory
+                current_chunk.sort(key=key, reverse=reverse)
+                path = self._get_path(len(chunks))
+                with open(path, 'w') as f:
+                    self.serializer.dump_stream(current_chunk, f)
+                self._spilled_bytes += os.path.getsize(path)
+                chunks.append(self.serializer.load_stream(open(path)))
+                current_chunk = []
+
+            elif not chunks:
+                batch = min(batch * 2, 10000)
+
+        current_chunk.sort(key=key, reverse=reverse)
+        if not chunks:
+            return current_chunk
+
+        if current_chunk:
+            chunks.append(iter(current_chunk))
+
+        return heapq.merge(chunks, key=key, reverse=reverse)
+
+
if __name__ == "__main__":
    import doctest
    doctest.testmod()
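For context on the new directory helpers: _get_local_dirs() reads SPARK_LOCAL_DIRS and shuffles the listed directories once per process, and ExternalSorter._get_path(n) then picks local_dirs[n % len(local_dirs)], so successive spill files alternate across the configured disks. A small sketch of that behaviour, assuming two illustrative /tmp paths (the environment variable must be set before the sorter is constructed, and the disk order is randomized per process):

    import os

    # illustrative local dirs; on a real cluster SPARK_LOCAL_DIRS is set by the worker
    os.environ["SPARK_LOCAL_DIRS"] = "/tmp/disk1,/tmp/disk2"

    from pyspark.shuffle import ExternalSorter
    sorter = ExternalSorter(512)                # 512 MB limit, illustrative

    # spill n lands under local_dirs[n % len(local_dirs)], so spills 0 and 1
    # go to different disks (the first disk chosen depends on the per-process shuffle)
    print(sorter._get_path(0))                  # e.g. /tmp/disk1/python/<pid>/sort/0
    print(sorter._get_path(1))                  # e.g. /tmp/disk2/python/<pid>/sort/1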