Diffstat (limited to 'python/pyspark/shuffle.py')
-rw-r--r-- | python/pyspark/shuffle.py | 19
1 file changed, 16 insertions, 3 deletions
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 49829f5280..ce597cbe91 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -68,6 +68,11 @@ def _get_local_dirs(sub):
     return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]
 
 
+# global stats
+MemoryBytesSpilled = 0L
+DiskBytesSpilled = 0L
+
+
 class Aggregator(object):
 
     """
@@ -313,10 +318,12 @@ class ExternalMerger(Merger):
 
         It will dump the data in batch for better performance.
         """
+        global MemoryBytesSpilled, DiskBytesSpilled
         path = self._get_spill_dir(self.spills)
         if not os.path.exists(path):
             os.makedirs(path)
 
+        used_memory = get_used_memory()
         if not self.pdata:
             # The data has not been partitioned, it will iterator the
             # dataset once, write them into different files, has no
@@ -334,6 +341,7 @@ class ExternalMerger(Merger):
                 self.serializer.dump_stream([(k, v)], streams[h])
 
             for s in streams:
+                DiskBytesSpilled += s.tell()
                 s.close()
 
             self.data.clear()
@@ -346,9 +354,11 @@ class ExternalMerger(Merger):
                 # dump items in batch
                 self.serializer.dump_stream(self.pdata[i].iteritems(), f)
             self.pdata[i].clear()
+            DiskBytesSpilled += os.path.getsize(p)
 
         self.spills += 1
         gc.collect()  # release the memory as much as possible
+        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
 
     def iteritems(self):
         """ Return all merged items as iterator """
@@ -462,7 +472,6 @@ class ExternalSorter(object):
         self.memory_limit = memory_limit
         self.local_dirs = _get_local_dirs("sort")
         self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
-        self._spilled_bytes = 0
 
     def _get_path(self, n):
         """ Choose one directory for spill by number n """
@@ -476,6 +485,7 @@ class ExternalSorter(object):
         Sort the elements in iterator, do external sort when the memory
         goes above the limit.
         """
+        global MemoryBytesSpilled, DiskBytesSpilled
         batch = 10
         chunks, current_chunk = [], []
         iterator = iter(iterator)
@@ -486,15 +496,18 @@ class ExternalSorter(object):
                 if len(chunk) < batch:
                     break
 
-            if get_used_memory() > self.memory_limit:
+            used_memory = get_used_memory()
+            if used_memory > self.memory_limit:
                 # sort them inplace will save memory
                 current_chunk.sort(key=key, reverse=reverse)
                 path = self._get_path(len(chunks))
                 with open(path, 'w') as f:
                     self.serializer.dump_stream(current_chunk, f)
-                self._spilled_bytes += os.path.getsize(path)
                 chunks.append(self.serializer.load_stream(open(path)))
                 current_chunk = []
+                gc.collect()
+                MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+                DiskBytesSpilled += os.path.getsize(path)
 
             elif not chunks:
                 batch = min(batch * 2, 10000)
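For context, the accounting pattern this change introduces: get_used_memory() reports usage in megabytes, so the amount freed by a spill is shifted left by 20 bits to convert MB to bytes before it is added to MemoryBytesSpilled, while DiskBytesSpilled grows by the size of each spill file. Below is a minimal standalone sketch of the same bookkeeping; the spill() helper and the resource-based memory probe are illustrative stand-ins, not pyspark.shuffle APIs.

    import gc
    import os
    import resource

    # Module-level counters, mirroring MemoryBytesSpilled / DiskBytesSpilled.
    MemoryBytesSpilled = 0
    DiskBytesSpilled = 0


    def get_used_memory():
        # Rough stand-in for pyspark.shuffle.get_used_memory(), which reports
        # memory usage in MB. ru_maxrss is in KB on Linux, hence >> 10.
        return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss >> 10


    def spill(items, path):
        # Hypothetical helper showing the bookkeeping around one spill:
        # snapshot memory, write to disk, collect garbage, then credit the
        # freed megabytes (<< 20 converts MB to bytes) and the file size.
        global MemoryBytesSpilled, DiskBytesSpilled
        used_memory = get_used_memory()
        with open(path, "w") as f:
            for item in items:
                f.write("%s\n" % (item,))
        gc.collect()  # release memory before measuring the delta
        # ru_maxrss is a high-water mark and never decreases, so the delta
        # here is usually zero; clamp it so the counter cannot go negative.
        MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20
        DiskBytesSpilled += os.path.getsize(path)


    if __name__ == "__main__":
        spill(range(100000), "/tmp/spill_demo")
        print("spilled %d bytes to disk" % DiskBytesSpilled)

Note that pyspark.shuffle prefers psutil's current RSS when psutil is installed, which is what makes the before/after subtraction meaningful in the real code; the getrusage probe above is only a dependency-free approximation.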