Diffstat (limited to 'python/pyspark/shuffle.py')
-rw-r--r--  python/pyspark/shuffle.py | 19
1 files changed, 16 insertions, 3 deletions
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 49829f5280..ce597cbe91 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -68,6 +68,11 @@ def _get_local_dirs(sub):
return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]
+# global stats
+MemoryBytesSpilled = 0L
+DiskBytesSpilled = 0L
+
+
class Aggregator(object):
"""
@@ -313,10 +318,12 @@ class ExternalMerger(Merger):
It will dump the data in batch for better performance.
"""
+ global MemoryBytesSpilled, DiskBytesSpilled
path = self._get_spill_dir(self.spills)
if not os.path.exists(path):
os.makedirs(path)
+ used_memory = get_used_memory()
if not self.pdata:
# The data has not been partitioned, it will iterate the
# dataset once, write them into different files, has no
@@ -334,6 +341,7 @@ class ExternalMerger(Merger):
self.serializer.dump_stream([(k, v)], streams[h])
for s in streams:
+ DiskBytesSpilled += s.tell()
s.close()
self.data.clear()
@@ -346,9 +354,11 @@ class ExternalMerger(Merger):
# dump items in batch
self.serializer.dump_stream(self.pdata[i].iteritems(), f)
self.pdata[i].clear()
+ DiskBytesSpilled += os.path.getsize(p)
self.spills += 1
gc.collect() # release the memory as much as possible
+ MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
def iteritems(self):
""" Return all merged items as iterator """
@@ -462,7 +472,6 @@ class ExternalSorter(object):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
- self._spilled_bytes = 0
def _get_path(self, n):
""" Choose one directory for spill by number n """
@@ -476,6 +485,7 @@ class ExternalSorter(object):
Sort the elements in iterator, do external sort when the memory
goes above the limit.
"""
+ global MemoryBytesSpilled, DiskBytesSpilled
batch = 10
chunks, current_chunk = [], []
iterator = iter(iterator)
@@ -486,15 +496,18 @@ class ExternalSorter(object):
if len(chunk) < batch:
break
- if get_used_memory() > self.memory_limit:
+ used_memory = get_used_memory()
+ if used_memory > self.memory_limit:
# sort them inplace will save memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
with open(path, 'w') as f:
self.serializer.dump_stream(current_chunk, f)
- self._spilled_bytes += os.path.getsize(path)
chunks.append(self.serializer.load_stream(open(path)))
current_chunk = []
+ gc.collect()
+ MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+ DiskBytesSpilled += os.path.getsize(path)
elif not chunks:
batch = min(batch * 2, 10000)
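
The << 20 shift in the hunks above assumes get_used_memory() reports usage in megabytes, so the freed amount (used_memory - get_used_memory()) is scaled by 1024*1024 to record bytes. Below is a minimal sketch of how a caller might read and reset the two module-level counters after a spill-heavy stage; the report_spill_metrics() helper is illustrative only and not part of this patch, and how the counters are actually surfaced to task metrics is outside this diff.

    import pyspark.shuffle as shuffle

    def report_spill_metrics():
        # Read the plain module-level globals accumulated by
        # ExternalMerger and ExternalSorter when they spill.
        mem_bytes = shuffle.MemoryBytesSpilled   # memory freed by spilling, in bytes
        disk_bytes = shuffle.DiskBytesSpilled    # bytes written to local spill files
        print("spilled %d bytes of memory, %d bytes to disk" % (mem_bytes, disk_bytes))
        # Reset between tasks so each task reports only its own spills
        # (assumption: per-task reporting is the intended usage).
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0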