Diffstat (limited to 'python/pyspark/shuffle.py')
-rw-r--r--  python/pyspark/shuffle.py | 531
1 file changed, 414 insertions(+), 117 deletions(-)
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 10a7ccd502..8a6fc627eb 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -16,28 +16,35 @@
#
import os
-import sys
import platform
import shutil
import warnings
import gc
import itertools
+import operator
import random
import pyspark.heapq3 as heapq
-from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
+from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
+ CompressedSerializer, AutoBatchedSerializer
+
try:
import psutil
+ process = None
+
def get_used_memory():
""" Return the used memory in MB """
- process = psutil.Process(os.getpid())
+ global process
+ if process is None or process._pid != os.getpid():
+ process = psutil.Process(os.getpid())
if hasattr(process, "memory_info"):
info = process.memory_info()
else:
info = process.get_memory_info()
return info.rss >> 20
+
except ImportError:
def get_used_memory():
@@ -46,6 +53,7 @@ except ImportError:
for line in open('/proc/self/status'):
if line.startswith('VmRSS:'):
return int(line.split()[1]) >> 10
+
else:
warnings.warn("Please install psutil to have better "
"support with spilling")
@@ -54,6 +62,7 @@ except ImportError:
rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
return rss >> 20
# TODO: support windows
+
return 0
@@ -148,10 +157,16 @@ class InMemoryMerger(Merger):
d[k] = comb(d[k], v) if k in d else v
def iteritems(self):
- """ Return the merged items ad iterator """
+ """ Return the merged items as iterator """
return self.data.iteritems()
+def _compressed_serializer(self, serializer=None):
+ # always use PickleSerializer to simplify implementation
+ ser = PickleSerializer()
+ return AutoBatchedSerializer(CompressedSerializer(ser))
+
+
class ExternalMerger(Merger):
"""
@@ -173,7 +188,7 @@ class ExternalMerger(Merger):
dict. Repeat this again until combine all the items.
- Before return any items, it will load each partition and
- combine them seperately. Yield them before loading next
+ combine them separately. Yield them before loading next
partition.
- During loading a partition, if the memory goes over limit,
@@ -182,7 +197,7 @@ class ExternalMerger(Merger):
`data` and `pdata` are used to hold the merged items in memory.
At first, all the data are merged into `data`. Once the used
- memory goes over limit, the items in `data` are dumped indo
+ memory goes over limit, the items in `data` are dumped into
disks, `data` will be cleared, all rest of items will be merged
into `pdata` and then dumped into disks. Before returning, all
the items in `pdata` will be dumped into disks.
@@ -193,16 +208,16 @@ class ExternalMerger(Merger):
>>> agg = SimpleAggregator(lambda x, y: x + y)
>>> merger = ExternalMerger(agg, 10)
>>> N = 10000
- >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10)
+ >>> merger.mergeValues(zip(xrange(N), xrange(N)))
>>> assert merger.spills > 0
>>> sum(v for k,v in merger.iteritems())
- 499950000
+ 49995000
>>> merger = ExternalMerger(agg, 10)
- >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10)
+ >>> merger.mergeCombiners(zip(xrange(N), xrange(N)))
>>> assert merger.spills > 0
>>> sum(v for k,v in merger.iteritems())
- 499950000
+ 49995000
"""
# the max total partitions created recursively
@@ -212,8 +227,7 @@ class ExternalMerger(Merger):
localdirs=None, scale=1, partitions=59, batch=1000):
Merger.__init__(self, aggregator)
self.memory_limit = memory_limit
- # default serializer is only used for tests
- self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
+ self.serializer = _compressed_serializer(serializer)
self.localdirs = localdirs or _get_local_dirs(str(id(self)))
# number of partitions when spill data into disks
self.partitions = partitions
@@ -221,7 +235,7 @@ class ExternalMerger(Merger):
self.batch = batch
# scale is used to scale down the hash of key for recursive hash map
self.scale = scale
- # unpartitioned merged data
+ # un-partitioned merged data
self.data = {}
# partitioned merged data, list of dicts
self.pdata = []
@@ -244,72 +258,63 @@ class ExternalMerger(Merger):
def mergeValues(self, iterator):
""" Combine the items by creator and combiner """
- iterator = iter(iterator)
# speedup attribute lookup
creator, comb = self.agg.createCombiner, self.agg.mergeValue
- d, c, batch = self.data, 0, self.batch
+ c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch
+ limit = self.memory_limit
for k, v in iterator:
+ d = pdata[hfun(k)] if pdata else data
d[k] = comb(d[k], v) if k in d else creator(v)
c += 1
- if c % batch == 0 and get_used_memory() > self.memory_limit:
- self._spill()
- self._partitioned_mergeValues(iterator, self._next_limit())
- break
+ if c >= batch:
+ if get_used_memory() >= limit:
+ self._spill()
+ limit = self._next_limit()
+ batch /= 2
+ c = 0
+ else:
+ batch *= 1.5
+
+ if get_used_memory() >= limit:
+ self._spill()
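For reference, a condensed standalone sketch of the adaptive spill check that the rewritten mergeValues (and mergeCombiners below) now use: memory is polled only every `batch` items, the polling interval grows by 1.5x while usage stays under the limit, and it is halved after a spill so the next check comes sooner. The helper names merge_one, spill and next_limit are hypothetical stand-ins for the merger's own methods; only get_used_memory comes from this module.

def merge_with_adaptive_checks(items, merge_one, spill, next_limit, limit, batch=1000):
    # merge_one/spill/next_limit are stand-ins for the merger's methods
    c = 0
    for item in items:
        merge_one(item)
        c += 1
        if c >= batch:
            if get_used_memory() >= limit:   # RSS of this process, in MB
                spill()
                limit = next_limit()
                batch /= 2                   # re-check sooner after spilling
                c = 0
            else:
                batch *= 1.5                 # poll less often while under the limit
    if get_used_memory() >= limit:           # final check before returning
        spill()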
def _partition(self, key):
""" Return the partition for key """
return hash((key, self._seed)) % self.partitions
- def _partitioned_mergeValues(self, iterator, limit=0):
- """ Partition the items by key, then combine them """
- # speedup attribute lookup
- creator, comb = self.agg.createCombiner, self.agg.mergeValue
- c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch
-
- for k, v in iterator:
- d = pdata[hfun(k)]
- d[k] = comb(d[k], v) if k in d else creator(v)
- if not limit:
- continue
-
- c += 1
- if c % batch == 0 and get_used_memory() > limit:
- self._spill()
- limit = self._next_limit()
+ def _object_size(self, obj):
+ """ Estimate the memory footprint of this obj, assuming that all the
+ objects consume a similar number of bytes
+ """
+ return 1
- def mergeCombiners(self, iterator, check=True):
+ def mergeCombiners(self, iterator, limit=None):
""" Merge (K,V) pair by mergeCombiner """
- iterator = iter(iterator)
+ if limit is None:
+ limit = self.memory_limit
# speedup attribute lookup
- d, comb, batch = self.data, self.agg.mergeCombiners, self.batch
- c = 0
- for k, v in iterator:
- d[k] = comb(d[k], v) if k in d else v
- if not check:
- continue
-
- c += 1
- if c % batch == 0 and get_used_memory() > self.memory_limit:
- self._spill()
- self._partitioned_mergeCombiners(iterator, self._next_limit())
- break
-
- def _partitioned_mergeCombiners(self, iterator, limit=0):
- """ Partition the items by key, then merge them """
- comb, pdata = self.agg.mergeCombiners, self.pdata
- c, hfun = 0, self._partition
+ comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size
+ c, data, pdata, batch = 0, self.data, self.pdata, self.batch
for k, v in iterator:
- d = pdata[hfun(k)]
+ d = pdata[hfun(k)] if pdata else data
d[k] = comb(d[k], v) if k in d else v
if not limit:
continue
- c += 1
- if c % self.batch == 0 and get_used_memory() > limit:
- self._spill()
- limit = self._next_limit()
+ c += objsize(v)
+ if c > batch:
+ if get_used_memory() > limit:
+ self._spill()
+ limit = self._next_limit()
+ batch /= 2
+ c = 0
+ else:
+ batch *= 1.5
+
+ if limit and get_used_memory() >= limit:
+ self._spill()
def _spill(self):
"""
@@ -335,7 +340,7 @@ class ExternalMerger(Merger):
for k, v in self.data.iteritems():
h = self._partition(k)
- # put one item in batch, make it compatitable with load_stream
+ # put one item in batch, make it compatible with load_stream
# it will increase the memory if dump them in batch
self.serializer.dump_stream([(k, v)], streams[h])
@@ -344,7 +349,7 @@ class ExternalMerger(Merger):
s.close()
self.data.clear()
- self.pdata = [{} for i in range(self.partitions)]
+ self.pdata.extend([{} for i in range(self.partitions)])
else:
for i in range(self.partitions):
@@ -370,29 +375,12 @@ class ExternalMerger(Merger):
assert not self.data
if any(self.pdata):
self._spill()
- hard_limit = self._next_limit()
+ # disable partitioning and spilling when merging combiners from disk
+ self.pdata = []
try:
for i in range(self.partitions):
- self.data = {}
- for j in range(self.spills):
- path = self._get_spill_dir(j)
- p = os.path.join(path, str(i))
- # do not check memory during merging
- self.mergeCombiners(self.serializer.load_stream(open(p)),
- False)
-
- # limit the total partitions
- if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
- and j < self.spills - 1
- and get_used_memory() > hard_limit):
- self.data.clear() # will read from disk again
- gc.collect() # release the memory as much as possible
- for v in self._recursive_merged_items(i):
- yield v
- return
-
- for v in self.data.iteritems():
+ for v in self._merged_items(i):
yield v
self.data.clear()
@@ -400,53 +388,56 @@ class ExternalMerger(Merger):
for j in range(self.spills):
path = self._get_spill_dir(j)
os.remove(os.path.join(path, str(i)))
-
finally:
self._cleanup()
- def _cleanup(self):
- """ Clean up all the files in disks """
- for d in self.localdirs:
- shutil.rmtree(d, True)
+ def _merged_items(self, index):
+ self.data = {}
+ limit = self._next_limit()
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(index))
+ # do not check memory during merging
+ self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+
+ # limit the total partitions
+ if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
+ and j < self.spills - 1
+ and get_used_memory() > limit):
+ self.data.clear() # will read from disk again
+ gc.collect() # release the memory as much as possible
+ return self._recursive_merged_items(index)
- def _recursive_merged_items(self, start):
+ return self.data.iteritems()
+
+ def _recursive_merged_items(self, index):
"""
merge the partitioned items and return the as iterator
If one partition can not be fit in memory, then them will be
partitioned and merged recursively.
"""
- # make sure all the data are dumps into disks.
- assert not self.data
- if any(self.pdata):
- self._spill()
- assert self.spills > 0
-
- for i in range(start, self.partitions):
- subdirs = [os.path.join(d, "parts", str(i))
- for d in self.localdirs]
- m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
- subdirs, self.scale * self.partitions, self.partitions)
- m.pdata = [{} for _ in range(self.partitions)]
- limit = self._next_limit()
-
- for j in range(self.spills):
- path = self._get_spill_dir(j)
- p = os.path.join(path, str(i))
- m._partitioned_mergeCombiners(
- self.serializer.load_stream(open(p)))
-
- if get_used_memory() > limit:
- m._spill()
- limit = self._next_limit()
+ subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs]
+ m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs,
+ self.scale * self.partitions, self.partitions, self.batch)
+ m.pdata = [{} for _ in range(self.partitions)]
+ limit = self._next_limit()
+
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(index))
+ m.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+
+ if get_used_memory() > limit:
+ m._spill()
+ limit = self._next_limit()
- for v in m._external_items():
- yield v
+ return m._external_items()
- # remove the merged partition
- for j in range(self.spills):
- path = self._get_spill_dir(j)
- os.remove(os.path.join(path, str(i)))
+ def _cleanup(self):
+ """ Clean up all the files in disks """
+ for d in self.localdirs:
+ shutil.rmtree(d, True)
class ExternalSorter(object):
@@ -457,6 +448,7 @@ class ExternalSorter(object):
The spilling will only happen when the used memory goes above
the limit.
+
>>> sorter = ExternalSorter(1) # 1M
>>> import random
>>> l = range(1024)
@@ -469,7 +461,7 @@ class ExternalSorter(object):
def __init__(self, memory_limit, serializer=None):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
- self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
+ self.serializer = _compressed_serializer(serializer)
def _get_path(self, n):
""" Choose one directory for spill by number n """
@@ -515,6 +507,7 @@ class ExternalSorter(object):
limit = self._next_limit()
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
DiskBytesSpilled += os.path.getsize(path)
+ os.unlink(path) # data will be deleted after close
elif not chunks:
batch = min(batch * 2, 10000)
@@ -529,6 +522,310 @@ class ExternalSorter(object):
return heapq.merge(chunks, key=key, reverse=reverse)
+class ExternalList(object):
+ """
+ ExternalList can hold many items that cannot fit in memory at
+ the same time.
+
+ >>> l = ExternalList(range(100))
+ >>> len(l)
+ 100
+ >>> l.append(10)
+ >>> len(l)
+ 101
+ >>> for i in range(20240):
+ ... l.append(i)
+ >>> len(l)
+ 20341
+ >>> import pickle
+ >>> l2 = pickle.loads(pickle.dumps(l))
+ >>> len(l2)
+ 20341
+ >>> list(l2)[100]
+ 10
+ """
+ LIMIT = 10240
+
+ def __init__(self, values):
+ self.values = values
+ self.count = len(values)
+ self._file = None
+ self._ser = None
+
+ def __getstate__(self):
+ if self._file is not None:
+ self._file.flush()
+ f = os.fdopen(os.dup(self._file.fileno()))
+ f.seek(0)
+ serialized = f.read()
+ else:
+ serialized = ''
+ return self.values, self.count, serialized
+
+ def __setstate__(self, item):
+ self.values, self.count, serialized = item
+ if serialized:
+ self._open_file()
+ self._file.write(serialized)
+ else:
+ self._file = None
+ self._ser = None
+
+ def __iter__(self):
+ if self._file is not None:
+ self._file.flush()
+ # read all items from disks first
+ with os.fdopen(os.dup(self._file.fileno()), 'r') as f:
+ f.seek(0)
+ for v in self._ser.load_stream(f):
+ yield v
+
+ for v in self.values:
+ yield v
+
+ def __len__(self):
+ return self.count
+
+ def append(self, value):
+ self.values.append(value)
+ self.count += 1
+ # dump the buffered values into disk once this key has too many of them
+ if len(self.values) >= self.LIMIT:
+ self._spill()
+
+ def _open_file(self):
+ dirs = _get_local_dirs("objects")
+ d = dirs[id(self) % len(dirs)]
+ if not os.path.exists(d):
+ os.makedirs(d)
+ p = os.path.join(d, str(id(self)))
+ self._file = open(p, "w+", 65536)
+ self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
+ os.unlink(p)
+
+ def _spill(self):
+ """ dump the values into disk """
+ global MemoryBytesSpilled, DiskBytesSpilled
+ if self._file is None:
+ self._open_file()
+
+ used_memory = get_used_memory()
+ pos = self._file.tell()
+ self._ser.dump_stream(self.values, self._file)
+ self.values = []
+ gc.collect()
+ DiskBytesSpilled += self._file.tell() - pos
+ MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+
+
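A note on ExternalList._open_file and __iter__ above: the scratch file is unlink()ed right after it is opened, so it has no name on disk and its space is reclaimed as soon as the handle is closed, and reads go through a dup()ed descriptor so that closing the reader does not close the writer. A minimal standalone sketch of that pattern (tempfile is used here only for illustration; it is not what the class does):

import os
import tempfile

fd, path = tempfile.mkstemp()
w = os.fdopen(fd, "w+")
os.unlink(path)                     # no name left on disk; freed once `w` is closed
w.write("spilled data")
w.flush()

r = os.fdopen(os.dup(w.fileno()))   # dup() shares the underlying file offset
r.seek(0)
assert r.read() == "spilled data"
r.close()                           # `w` stays open and writable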
+class ExternalListOfList(ExternalList):
+ """
+ An external list of lists; its length counts the flattened values.
+
+ >>> l = ExternalListOfList([[i, i] for i in range(100)])
+ >>> len(l)
+ 200
+ >>> l.append(range(10))
+ >>> len(l)
+ 210
+ >>> len(list(l))
+ 210
+ """
+
+ def __init__(self, values):
+ ExternalList.__init__(self, values)
+ self.count = sum(len(i) for i in values)
+
+ def append(self, value):
+ ExternalList.append(self, value)
+ # already counted 1 in ExternalList.append
+ self.count += len(value) - 1
+
+ def __iter__(self):
+ for values in ExternalList.__iter__(self):
+ for v in values:
+ yield v
+
+
+class GroupByKey(object):
+ """
+ Group a sorted iterator as [(k1, it1), (k2, it2), ...]
+
+ >>> k = [i/3 for i in range(6)]
+ >>> v = [[i] for i in range(6)]
+ >>> g = GroupByKey(iter(zip(k, v)))
+ >>> [(k, list(it)) for k, it in g]
+ [(0, [0, 1, 2]), (1, [3, 4, 5])]
+ """
+
+ def __init__(self, iterator):
+ self.iterator = iter(iterator)
+ self.next_item = None
+
+ def __iter__(self):
+ return self
+
+ def next(self):
+ key, value = self.next_item if self.next_item else next(self.iterator)
+ values = ExternalListOfList([value])
+ try:
+ while True:
+ k, v = next(self.iterator)
+ if k != key:
+ self.next_item = (k, v)
+ break
+ values.append(v)
+ except StopIteration:
+ self.next_item = None
+ return key, values
+
+
+class ExternalGroupBy(ExternalMerger):
+
+ """
+ Group the items by key. If any partition of them cannot be
+ held in memory, it will do a sort-based group by.
+
+ This class works as follows:
+
+ - It repeatedly groups the items by key and saves them in one dict in
+ memory.
+
+ - When the used memory goes above the memory limit, it will split
+ the combined data into partitions by hash code, and dump them
+ to disk, one file per partition. If the number of keys
+ in one partition is smaller than 1000, it will sort them
+ by key before dumping to disk.
+
+ - Then it goes through the rest of the iterator, grouping items
+ by key into different dicts by hash. Once the used memory goes over
+ the memory limit, it dumps all the dicts to disk, one file per
+ dict, and repeats this until all the items are combined. It
+ will also try to sort the items by key in each partition
+ before dumping to disk.
+
+ - It will yield the grouped items partition by partition.
+ If the data in one partition can be held in memory, it
+ will load and combine them in memory and yield the result.
+
+ - If the dataset in one partition cannot be held in memory,
+ it will sort the items first. If all the files are already sorted,
+ it merges them by heapq.merge(); otherwise it will do an external
+ sort over all the files.
+
+ - After sorting, the `GroupByKey` class will group all the consecutive
+ items with the same key into one group, yielding the values as
+ an iterator.
+ """
+ SORT_KEY_LIMIT = 1000
+
+ def flattened_serializer(self):
+ assert isinstance(self.serializer, BatchedSerializer)
+ ser = self.serializer
+ return FlattenedValuesSerializer(ser, 20)
+
+ def _object_size(self, obj):
+ return len(obj)
+
+ def _spill(self):
+ """
+ dump already partitioned data into disks.
+ """
+ global MemoryBytesSpilled, DiskBytesSpilled
+ path = self._get_spill_dir(self.spills)
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ used_memory = get_used_memory()
+ if not self.pdata:
+ # The data has not been partitioned yet: iterate over it
+ # once and write the items into different files, using no
+ # additional memory. This branch is only taken the first
+ # time the memory goes above the limit.
+
+ # open all the files for writing
+ streams = [open(os.path.join(path, str(i)), 'w')
+ for i in range(self.partitions)]
+
+ # If the number of keys is small, the overhead of sorting is small;
+ # sort them by key before dumping to disk
+ self._sorted = len(self.data) < self.SORT_KEY_LIMIT
+ if self._sorted:
+ self.serializer = self.flattened_serializer()
+ for k in sorted(self.data.keys()):
+ h = self._partition(k)
+ self.serializer.dump_stream([(k, self.data[k])], streams[h])
+ else:
+ for k, v in self.data.iteritems():
+ h = self._partition(k)
+ self.serializer.dump_stream([(k, v)], streams[h])
+
+ for s in streams:
+ DiskBytesSpilled += s.tell()
+ s.close()
+
+ self.data.clear()
+ # self.pdata is cached in `mergeValues` and `mergeCombiners`
+ self.pdata.extend([{} for i in range(self.partitions)])
+
+ else:
+ for i in range(self.partitions):
+ p = os.path.join(path, str(i))
+ with open(p, "w") as f:
+ # dump items in batch
+ if self._sorted:
+ # sort by key only (stable)
+ sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0))
+ self.serializer.dump_stream(sorted_items, f)
+ else:
+ self.serializer.dump_stream(self.pdata[i].iteritems(), f)
+ self.pdata[i].clear()
+ DiskBytesSpilled += os.path.getsize(p)
+
+ self.spills += 1
+ gc.collect() # release the memory as much as possible
+ MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+
+ def _merged_items(self, index):
+ size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
+ for j in range(self.spills))
+ # if the memory cannot hold the whole partition, then use a
+ # sort-based merge. Because of compression, the data on
+ # disk will be much smaller than the memory needed to hold it
+ if (size >> 20) >= self.memory_limit / 10:
+ return self._merge_sorted_items(index)
+
+ self.data = {}
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(index))
+ # do not check memory during merging
+ self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+ return self.data.iteritems()
+
+ def _merge_sorted_items(self, index):
+ """ load a partition from disk, then sort and group by key """
+ def load_partition(j):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(index))
+ return self.serializer.load_stream(open(p, 'r', 65536))
+
+ disk_items = [load_partition(j) for j in range(self.spills)]
+
+ if self._sorted:
+ # all the partitions are already sorted
+ sorted_items = heapq.merge(disk_items, key=operator.itemgetter(0))
+
+ else:
+ # Flatten the combined values, so they will not consume huge
+ # amounts of memory during the merge sort.
+ ser = self.flattened_serializer()
+ sorter = ExternalSorter(self.memory_limit, ser)
+ sorted_items = sorter.sorted(itertools.chain(*disk_items),
+ key=operator.itemgetter(0))
+ return ((k, vs) for k, vs in GroupByKey(sorted_items))
+
+
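For orientation, a minimal local usage sketch of the new group-by path. This is not the exact wiring used by rdd.py; the Aggregator triple below is an assumption built on the module's Aggregator class (defined earlier in shuffle.py, not shown in this diff), and the 10 MB limit mirrors the doctests above so that spilling actually happens.

agg = Aggregator(lambda v: [v],                       # createCombiner
                 lambda l, v: l.append(v) or l,       # mergeValue
                 lambda l1, l2: l1.extend(l2) or l1)  # mergeCombiners
merger = ExternalGroupBy(agg, 10)
merger.mergeValues((i % 7, i) for i in xrange(10000))
grouped = dict((k, list(vs)) for k, vs in merger.iteritems())
assert sorted(grouped.keys()) == range(7)
assert sum(len(vs) for vs in grouped.itervalues()) == 10000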
if __name__ == "__main__":
import doctest
doctest.testmod()