path: root/python/pyspark/shuffle.py
Diffstat (limited to 'python/pyspark/shuffle.py')
-rw-r--r--  python/pyspark/shuffle.py  126
1 file changed, 69 insertions, 57 deletions
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 8a6fc627eb..b54baa57ec 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -78,8 +78,8 @@ def _get_local_dirs(sub):
# global stats
-MemoryBytesSpilled = 0L
-DiskBytesSpilled = 0L
+MemoryBytesSpilled = 0
+DiskBytesSpilled = 0
class Aggregator(object):
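The first hunk drops the Python 2 long-literal suffix: Python 3 has a single arbitrary-precision int type, so a plain 0 behaves exactly as the old 0L once the spill counters grow past 32 bits. A minimal illustration, not part of the patch:

    # Python 3: one unbounded integer type, no "L" suffix
    MemoryBytesSpilled = 0
    MemoryBytesSpilled += 5 * (1 << 32)      # grows past 32 bits without overflow
    assert isinstance(MemoryBytesSpilled, int)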
@@ -126,7 +126,7 @@ class Merger(object):
""" Merge the combined items by mergeCombiner """
raise NotImplementedError
- def iteritems(self):
+ def items(self):
""" Return the merged items ad iterator """
raise NotImplementedError
@@ -156,9 +156,9 @@ class InMemoryMerger(Merger):
for k, v in iterator:
d[k] = comb(d[k], v) if k in d else v
- def iteritems(self):
- """ Return the merged items as iterator """
- return self.data.iteritems()
+ def items(self):
+ """ Return the merged items ad iterator """
+ return iter(self.data.items())
def _compressed_serializer(self, serializer=None):
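dict.iteritems() no longer exists in Python 3; items() returns a lightweight view, and wrapping it in iter() preserves the old lazy-iterator contract without copying the dictionary. A rough sketch of the difference, outside the patch:

    d = {"a": 1, "b": 2}
    it = iter(d.items())          # Python 3: items() is a view, iter() keeps it lazy
    assert next(it) in {("a", 1), ("b", 2)}
    # Python 2's d.iteritems() is gone; d.items() there built a full list instead.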
@@ -208,15 +208,15 @@ class ExternalMerger(Merger):
>>> agg = SimpleAggregator(lambda x, y: x + y)
>>> merger = ExternalMerger(agg, 10)
>>> N = 10000
- >>> merger.mergeValues(zip(xrange(N), xrange(N)))
+ >>> merger.mergeValues(zip(range(N), range(N)))
>>> assert merger.spills > 0
- >>> sum(v for k,v in merger.iteritems())
+ >>> sum(v for k,v in merger.items())
49995000
>>> merger = ExternalMerger(agg, 10)
- >>> merger.mergeCombiners(zip(xrange(N), xrange(N)))
+ >>> merger.mergeCombiners(zip(range(N), range(N)))
>>> assert merger.spills > 0
- >>> sum(v for k,v in merger.iteritems())
+ >>> sum(v for k,v in merger.items())
49995000
"""
@@ -335,10 +335,10 @@ class ExternalMerger(Merger):
# above limit at the first time.
# open all the files for writing
- streams = [open(os.path.join(path, str(i)), 'w')
+ streams = [open(os.path.join(path, str(i)), 'wb')
for i in range(self.partitions)]
- for k, v in self.data.iteritems():
+ for k, v in self.data.items():
h = self._partition(k)
# put one item in batch, make it compatible with load_stream
# it would increase the memory usage if we dumped them in one batch
@@ -354,9 +354,9 @@ class ExternalMerger(Merger):
else:
for i in range(self.partitions):
p = os.path.join(path, str(i))
- with open(p, "w") as f:
+ with open(p, "wb") as f:
# dump items in batch
- self.serializer.dump_stream(self.pdata[i].iteritems(), f)
+ self.serializer.dump_stream(iter(self.pdata[i].items()), f)
self.pdata[i].clear()
DiskBytesSpilled += os.path.getsize(p)
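Spill files move from mode 'w' to 'wb' because the serializers emit pickled (and possibly compressed) bytes; in Python 3 a text-mode file only accepts str, and writing bytes raises TypeError. A minimal sketch of the failure mode, using a hypothetical path:

    import pickle
    payload = pickle.dumps(("key", [1, 2, 3]))     # bytes, not str
    with open("/tmp/spill_demo", "wb") as f:       # hypothetical path
        f.write(payload)                           # with mode 'w' this would raise TypeError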
@@ -364,10 +364,10 @@ class ExternalMerger(Merger):
gc.collect() # release the memory as much as possible
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
- def iteritems(self):
+ def items(self):
""" Return all merged items as iterator """
if not self.pdata and not self.spills:
- return self.data.iteritems()
+ return iter(self.data.items())
return self._external_items()
def _external_items(self):
@@ -398,7 +398,8 @@ class ExternalMerger(Merger):
path = self._get_spill_dir(j)
p = os.path.join(path, str(index))
# do not check memory during merging
- self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+ with open(p, "rb") as f:
+ self.mergeCombiners(self.serializer.load_stream(f), 0)
# limit the total partitions
if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
@@ -408,7 +409,7 @@ class ExternalMerger(Merger):
gc.collect() # release the memory as much as possible
return self._recursive_merged_items(index)
- return self.data.iteritems()
+ return self.data.items()
def _recursive_merged_items(self, index):
"""
@@ -426,7 +427,8 @@ class ExternalMerger(Merger):
for j in range(self.spills):
path = self._get_spill_dir(j)
p = os.path.join(path, str(index))
- m.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+ with open(p, 'rb') as f:
+ m.mergeCombiners(self.serializer.load_stream(f), 0)
if get_used_memory() > limit:
m._spill()
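Reads of spill files are wrapped in with blocks so each descriptor is closed as soon as the stream has been merged, instead of waiting for the garbage collector; the old bare open(p) can also trigger ResourceWarning under Python 3. The shape of the change on a generic reader (names here are illustrative, not from the patch):

    def read_spill(path, serializer, merge):
        # the file is closed when the block exits, even if merge() raises
        with open(path, "rb") as f:
            merge(serializer.load_stream(f), 0)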
@@ -451,7 +453,7 @@ class ExternalSorter(object):
>>> sorter = ExternalSorter(1) # 1M
>>> import random
- >>> l = range(1024)
+ >>> l = list(range(1024))
>>> random.shuffle(l)
>>> sorted(l) == list(sorter.sorted(l))
True
@@ -499,9 +501,16 @@ class ExternalSorter(object):
# sorting them in place saves memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
- with open(path, 'w') as f:
+ with open(path, 'wb') as f:
self.serializer.dump_stream(current_chunk, f)
- chunks.append(self.serializer.load_stream(open(path)))
+
+ def load(f):
+ for v in self.serializer.load_stream(f):
+ yield v
+ # close the file explicitly once we consume all the items
+ # to avoid ResourceWarning in Python 3
+ f.close()
+ chunks.append(load(open(path, 'rb')))
current_chunk = []
gc.collect()
limit = self._next_limit()
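ExternalSorter keeps many spill chunks open at once and merges them lazily, so a with block would close each file before its items are consumed; instead every chunk is wrapped in a small generator that closes its file only after the last item is yielded. A standalone sketch of the same idiom, written slightly more defensively with try/finally and assuming any deserializer that yields items from a file:

    def load(f, load_stream):
        try:
            for v in load_stream(f):
                yield v
        finally:
            f.close()    # released once the chunk is exhausted (or the generator is discarded)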
@@ -527,7 +536,7 @@ class ExternalList(object):
ExternalList can have many items which cannot be held in memory at
the same time.
- >>> l = ExternalList(range(100))
+ >>> l = ExternalList(list(range(100)))
>>> len(l)
100
>>> l.append(10)
@@ -555,11 +564,11 @@ class ExternalList(object):
def __getstate__(self):
if self._file is not None:
self._file.flush()
- f = os.fdopen(os.dup(self._file.fileno()))
- f.seek(0)
- serialized = f.read()
+ with os.fdopen(os.dup(self._file.fileno()), "rb") as f:
+ f.seek(0)
+ serialized = f.read()
else:
- serialized = ''
+ serialized = b''
return self.values, self.count, serialized
def __setstate__(self, item):
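__getstate__ now duplicates the file descriptor, reads it back in binary mode inside a with block, and falls back to b'' rather than '' so the pickled state is always bytes. A reduced sketch of the dup-and-read pattern, using a temporary file as a stand-in for the spill file:

    import os, tempfile

    tmp = tempfile.TemporaryFile()
    tmp.write(b"serialized items")
    tmp.flush()
    with os.fdopen(os.dup(tmp.fileno()), "rb") as f:
        f.seek(0)
        serialized = f.read()        # bytes, matching the b'' fallback
    assert serialized == b"serialized items"
    tmp.close()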
@@ -575,7 +584,7 @@ class ExternalList(object):
if self._file is not None:
self._file.flush()
# read all items from disks first
- with os.fdopen(os.dup(self._file.fileno()), 'r') as f:
+ with os.fdopen(os.dup(self._file.fileno()), 'rb') as f:
f.seek(0)
for v in self._ser.load_stream(f):
yield v
@@ -598,11 +607,16 @@ class ExternalList(object):
d = dirs[id(self) % len(dirs)]
if not os.path.exists(d):
os.makedirs(d)
- p = os.path.join(d, str(id))
- self._file = open(p, "w+", 65536)
+ p = os.path.join(d, str(id(self)))
+ self._file = open(p, "wb+", 65536)
self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
os.unlink(p)
+ def __del__(self):
+ if self._file:
+ self._file.close()
+ self._file = None
+
def _spill(self):
""" dump the values into disk """
global MemoryBytesSpilled, DiskBytesSpilled
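Two fixes land in this hunk: the spill file name previously used str(id), which stringifies the id builtin itself rather than the object's identity, so it becomes str(id(self)); and a new __del__ closes the backing file so the descriptor of the already-unlinked temp file is released promptly. Opening the file and immediately unlinking it keeps the data reachable through the open handle while guaranteeing it disappears from the filesystem; a minimal sketch under POSIX semantics, with a hypothetical path:

    import os

    path = "/tmp/external_list_demo"   # hypothetical; the real name comes from str(id(self))
    f = open(path, "wb+", 65536)
    os.unlink(path)                    # the name is gone, the data lives until f is closed
    f.write(b"spilled batch")
    f.seek(0)
    assert f.read() == b"spilled batch"
    f.close()                          # this is what the new __del__ guarantees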
@@ -651,33 +665,28 @@ class GroupByKey(object):
"""
Group a sorted iterator as [(k1, it1), (k2, it2), ...]
- >>> k = [i/3 for i in range(6)]
+ >>> k = [i // 3 for i in range(6)]
>>> v = [[i] for i in range(6)]
- >>> g = GroupByKey(iter(zip(k, v)))
+ >>> g = GroupByKey(zip(k, v))
>>> [(k, list(it)) for k, it in g]
[(0, [0, 1, 2]), (1, [3, 4, 5])]
"""
def __init__(self, iterator):
- self.iterator = iter(iterator)
- self.next_item = None
+ self.iterator = iterator
def __iter__(self):
- return self
-
- def next(self):
- key, value = self.next_item if self.next_item else next(self.iterator)
- values = ExternalListOfList([value])
- try:
- while True:
- k, v = next(self.iterator)
- if k != key:
- self.next_item = (k, v)
- break
+ key, values = None, None
+ for k, v in self.iterator:
+ if values is not None and k == key:
values.append(v)
- except StopIteration:
- self.next_item = None
- return key, values
+ else:
+ if values is not None:
+ yield (key, values)
+ key = k
+ values = ExternalListOfList([v])
+ if values is not None:
+ yield (key, values)
class ExternalGroupBy(ExternalMerger):
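GroupByKey drops the stateful next()/next_item protocol, which Python 3 would not even invoke (iteration looks for __next__), in favor of a generator __iter__ that yields one (key, values) pair per run of equal keys; the doctest also switches to // so the keys stay integers under true division. The grouping itself is essentially itertools.groupby over an already-sorted stream, except that values are collected into an ExternalListOfList that can spill to disk. A purely in-memory comparison, not the class itself:

    from itertools import groupby

    pairs = [(0, "a"), (0, "b"), (1, "c")]     # already sorted by key
    grouped = [(k, [v for _, v in run])
               for k, run in groupby(pairs, key=lambda kv: kv[0])]
    assert grouped == [(0, ["a", "b"]), (1, ["c"])]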
@@ -744,7 +753,7 @@ class ExternalGroupBy(ExternalMerger):
# above limit at the first time.
# open all the files for writing
- streams = [open(os.path.join(path, str(i)), 'w')
+ streams = [open(os.path.join(path, str(i)), 'wb')
for i in range(self.partitions)]
# If the number of keys is small, then the overhead of sort is small
@@ -756,7 +765,7 @@ class ExternalGroupBy(ExternalMerger):
h = self._partition(k)
self.serializer.dump_stream([(k, self.data[k])], streams[h])
else:
- for k, v in self.data.iteritems():
+ for k, v in self.data.items():
h = self._partition(k)
self.serializer.dump_stream([(k, v)], streams[h])
@@ -771,14 +780,14 @@ class ExternalGroupBy(ExternalMerger):
else:
for i in range(self.partitions):
p = os.path.join(path, str(i))
- with open(p, "w") as f:
+ with open(p, "wb") as f:
# dump items in batch
if self._sorted:
# sort by key only (stable)
- sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0))
+ sorted_items = sorted(self.pdata[i].items(), key=operator.itemgetter(0))
self.serializer.dump_stream(sorted_items, f)
else:
- self.serializer.dump_stream(self.pdata[i].iteritems(), f)
+ self.serializer.dump_stream(self.pdata[i].items(), f)
self.pdata[i].clear()
DiskBytesSpilled += os.path.getsize(p)
@@ -792,7 +801,7 @@ class ExternalGroupBy(ExternalMerger):
# if the memory can not hold all the partition,
# then use sort based merge. Because of compression,
# the data on disks will be much smaller than needed memory
- if (size >> 20) >= self.memory_limit / 10:
+ if size >= self.memory_limit << 17: # * 1M / 8
return self._merge_sorted_items(index)
self.data = {}
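The spill-size check is rewritten in pure integer arithmetic: memory_limit is expressed in megabytes, so the old test compared size >> 20 (size in MB) with memory_limit / 10, which is a float under Python 3's true division. Shifting the other side instead, memory_limit << 17 equals memory_limit * 2**20 / 8, i.e. one eighth of the limit in bytes (the "* 1M / 8" comment), so the threshold moves from a tenth to an eighth of the memory limit while staying integer-only. Worked out for a hypothetical 512 MB limit:

    memory_limit = 512                       # MB, illustrative value
    threshold_bytes = memory_limit << 17     # 512 * 131072 = 67108864 bytes = 64 MB
    assert threshold_bytes == memory_limit * (1 << 20) // 8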
@@ -800,15 +809,18 @@ class ExternalGroupBy(ExternalMerger):
path = self._get_spill_dir(j)
p = os.path.join(path, str(index))
# do not check memory during merging
- self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
- return self.data.iteritems()
+ with open(p, "rb") as f:
+ self.mergeCombiners(self.serializer.load_stream(f), 0)
+ return self.data.items()
def _merge_sorted_items(self, index):
""" load a partition from disk, then sort and group by key """
def load_partition(j):
path = self._get_spill_dir(j)
p = os.path.join(path, str(index))
- return self.serializer.load_stream(open(p, 'r', 65536))
+ with open(p, 'rb', 65536) as f:
+ for v in self.serializer.load_stream(f):
+ yield v
disk_items = [load_partition(j) for j in range(self.spills)]
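load_partition(j) is now a generator that opens its spill file in buffered binary mode and only yields deserialized items, so the descriptor is closed once the stream is finished with. Each per-spill stream is already sorted by key, so the disk_items list can be merged lazily and handed to GroupByKey; a minimal sketch of that downstream step, assuming heapq.merge with a key function (Python 3.5+) and stand-in streams rather than the method's real inputs:

    import heapq
    import operator

    stream_a = iter([("a", 1), ("c", 3)])    # stand-ins for load_partition(j) outputs
    stream_b = iter([("a", 2), ("b", 5)])
    merged = heapq.merge(stream_a, stream_b, key=operator.itemgetter(0))
    assert [k for k, _ in merged] == ["a", "a", "b", "c"]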