-rw-r--r--  python/pyspark/join.py            |  13
-rw-r--r--  python/pyspark/rdd.py             |  48
-rw-r--r--  python/pyspark/resultiterable.py  |   7
-rw-r--r--  python/pyspark/serializers.py     |  25
-rw-r--r--  python/pyspark/shuffle.py         | 531
-rw-r--r--  python/pyspark/tests.py           |  50
6 files changed, 531 insertions(+), 143 deletions(-)
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index efc1ef9396..c3491defb2 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -48,7 +48,7 @@ def python_join(rdd, other, numPartitions):
vbuf.append(v)
elif n == 2:
wbuf.append(v)
- return [(v, w) for v in vbuf for w in wbuf]
+ return ((v, w) for v in vbuf for w in wbuf)
return _do_python_join(rdd, other, numPartitions, dispatch)
@@ -62,7 +62,7 @@ def python_right_outer_join(rdd, other, numPartitions):
wbuf.append(v)
if not vbuf:
vbuf.append(None)
- return [(v, w) for v in vbuf for w in wbuf]
+ return ((v, w) for v in vbuf for w in wbuf)
return _do_python_join(rdd, other, numPartitions, dispatch)
@@ -76,7 +76,7 @@ def python_left_outer_join(rdd, other, numPartitions):
wbuf.append(v)
if not wbuf:
wbuf.append(None)
- return [(v, w) for v in vbuf for w in wbuf]
+ return ((v, w) for v in vbuf for w in wbuf)
return _do_python_join(rdd, other, numPartitions, dispatch)
@@ -104,8 +104,9 @@ def python_cogroup(rdds, numPartitions):
rdd_len = len(vrdds)
def dispatch(seq):
- bufs = [[] for i in range(rdd_len)]
- for (n, v) in seq:
+ bufs = [[] for _ in range(rdd_len)]
+ for n, v in seq:
bufs[n].append(v)
- return tuple(map(ResultIterable, bufs))
+ return tuple(ResultIterable(vs) for vs in bufs)
+
return union_vrdds.groupByKey(numPartitions).mapValues(dispatch)
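
The three dispatch functions above now return generator expressions instead of list comprehensions, so the per-key cross product of values is produced lazily while the result is consumed rather than materialized as one list per key. A standalone sketch of the difference (illustrative only, not part of the patch):

    # A list comprehension builds the whole cross product up front, while a
    # generator expression yields one pair at a time -- which matters when a
    # hot key has many matching values on both sides of the join.
    vbuf, wbuf = range(300), range(300)

    eager = [(v, w) for v in vbuf for w in wbuf]   # 90,000 tuples held at once
    lazy = ((v, w) for v in vbuf for w in wbuf)    # nothing built yet

    print(len(eager))   # 90000
    print(next(lazy))   # (0, 0), produced on demand
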
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 2d05611321..1b18789040 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -41,7 +41,7 @@ from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
- get_used_memory, ExternalSorter
+ get_used_memory, ExternalSorter, ExternalGroupBy
from pyspark.traceback_utils import SCCallSiteSync
from py4j.java_collections import ListConverter, MapConverter
@@ -573,8 +573,8 @@ class RDD(object):
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
- spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
- memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+ spill = self._can_spill()
+ memory = self._memory_limit()
serializer = self._jrdd_deserializer
def sortPartition(iterator):
@@ -1699,10 +1699,8 @@ class RDD(object):
numPartitions = self._defaultReducePartitions()
serializer = self.ctx.serializer
- spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
- == 'true')
- memory = _parse_memory(self.ctx._conf.get(
- "spark.python.worker.memory", "512m"))
+ spill = self._can_spill()
+ memory = self._memory_limit()
agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
def combineLocally(iterator):
@@ -1755,21 +1753,28 @@ class RDD(object):
return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)
+ def _can_spill(self):
+ return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"
+
+ def _memory_limit(self):
+ return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+
# TODO: support variant with custom partitioner
def groupByKey(self, numPartitions=None):
"""
Group the values for each key in the RDD into a single sequence.
- Hash-partitions the resulting RDD with into numPartitions partitions.
+ Hash-partitions the resulting RDD with numPartitions partitions.
Note: If you are grouping in order to perform an aggregation (such as a
sum or average) over each key, using reduceByKey or aggregateByKey will
provide much better performance.
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
- >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
+ >>> sorted(x.groupByKey().mapValues(len).collect())
+ [('a', 2), ('b', 1)]
+ >>> sorted(x.groupByKey().mapValues(list).collect())
[('a', [1, 1]), ('b', [1])]
"""
-
def createCombiner(x):
return [x]
@@ -1781,8 +1786,27 @@ class RDD(object):
a.extend(b)
return a
- return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
- numPartitions).mapValues(lambda x: ResultIterable(x))
+ spill = self._can_spill()
+ memory = self._memory_limit()
+ serializer = self._jrdd_deserializer
+ agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
+
+ def combine(iterator):
+ merger = ExternalMerger(agg, memory * 0.9, serializer) \
+ if spill else InMemoryMerger(agg)
+ merger.mergeValues(iterator)
+ return merger.iteritems()
+
+ locally_combined = self.mapPartitions(combine, preservesPartitioning=True)
+ shuffled = locally_combined.partitionBy(numPartitions)
+
+ def groupByKey(it):
+ merger = ExternalGroupBy(agg, memory, serializer)\
+ if spill else InMemoryMerger(agg)
+ merger.mergeCombiners(it)
+ return merger.iteritems()
+
+ return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable)
def flatMapValues(self, f):
"""
diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py
index ef04c82866..1ab5ce14c3 100644
--- a/python/pyspark/resultiterable.py
+++ b/python/pyspark/resultiterable.py
@@ -15,15 +15,16 @@
# limitations under the License.
#
-__all__ = ["ResultIterable"]
-
import collections
+__all__ = ["ResultIterable"]
+
class ResultIterable(collections.Iterable):
"""
- A special result iterable. This is used because the standard iterator can not be pickled
+ A special result iterable. This is used because the standard
+ iterator can not be pickled
"""
def __init__(self, data):
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 0ffb41d02f..4afa82f4b2 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -220,6 +220,29 @@ class BatchedSerializer(Serializer):
return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)
+class FlattenedValuesSerializer(BatchedSerializer):
+
+ """
+ Serializes a stream of (key, values) pairs, splitting value lists
+ that contain more than a certain number of objects so that the
+ serialized batches have similar sizes.
+ """
+ def __init__(self, serializer, batchSize=10):
+ BatchedSerializer.__init__(self, serializer, batchSize)
+
+ def _batched(self, iterator):
+ n = self.batchSize
+ for key, values in iterator:
+ for i in xrange(0, len(values), n):
+ yield key, values[i:i + n]
+
+ def load_stream(self, stream):
+ return self.serializer.load_stream(stream)
+
+ def __repr__(self):
+ return "FlattenedValuesSerializer(%d)" % self.batchSize
+
+
class AutoBatchedSerializer(BatchedSerializer):
"""
Choose the size of batch automatically based on the size of object
@@ -251,7 +274,7 @@ class AutoBatchedSerializer(BatchedSerializer):
return (isinstance(other, AutoBatchedSerializer) and
other.serializer == self.serializer and other.bestSize == self.bestSize)
- def __str__(self):
+ def __repr__(self):
return "AutoBatchedSerializer(%s)" % str(self.serializer)
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 10a7ccd502..8a6fc627eb 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -16,28 +16,35 @@
#
import os
-import sys
import platform
import shutil
import warnings
import gc
import itertools
+import operator
import random
import pyspark.heapq3 as heapq
-from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
+from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
+ CompressedSerializer, AutoBatchedSerializer
+
try:
import psutil
+ process = None
+
def get_used_memory():
""" Return the used memory in MB """
- process = psutil.Process(os.getpid())
+ global process
+ if process is None or process._pid != os.getpid():
+ process = psutil.Process(os.getpid())
if hasattr(process, "memory_info"):
info = process.memory_info()
else:
info = process.get_memory_info()
return info.rss >> 20
+
except ImportError:
def get_used_memory():
@@ -46,6 +53,7 @@ except ImportError:
for line in open('/proc/self/status'):
if line.startswith('VmRSS:'):
return int(line.split()[1]) >> 10
+
else:
warnings.warn("Please install psutil to have better "
"support with spilling")
@@ -54,6 +62,7 @@ except ImportError:
rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
return rss >> 20
# TODO: support windows
+
return 0
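
The psutil branch above now caches the Process object at module level and only rebuilds it when the pid changes (e.g. after a worker fork), instead of constructing a new one on every call. A minimal sketch of that pattern, assuming a recent psutil where Process.pid and memory_info() are available (the real code also falls back to /proc or the resource module):

    import os
    import psutil

    _process = None

    def used_memory_mb():
        # Return this process's resident set size in MB.
        global _process
        if _process is None or _process.pid != os.getpid():   # refresh after fork
            _process = psutil.Process(os.getpid())
        return _process.memory_info().rss >> 20
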
@@ -148,10 +157,16 @@ class InMemoryMerger(Merger):
d[k] = comb(d[k], v) if k in d else v
def iteritems(self):
- """ Return the merged items ad iterator """
+ """ Return the merged items as iterator """
return self.data.iteritems()
+def _compressed_serializer(self, serializer=None):
+ # always use PickleSerializer to simplify implementation
+ ser = PickleSerializer()
+ return AutoBatchedSerializer(CompressedSerializer(ser))
+
+
class ExternalMerger(Merger):
"""
@@ -173,7 +188,7 @@ class ExternalMerger(Merger):
dict. Repeat this again until combine all the items.
- Before return any items, it will load each partition and
- combine them seperately. Yield them before loading next
+ combine them separately. Yield them before loading next
partition.
- During loading a partition, if the memory goes over limit,
@@ -182,7 +197,7 @@ class ExternalMerger(Merger):
`data` and `pdata` are used to hold the merged items in memory.
At first, all the data are merged into `data`. Once the used
- memory goes over limit, the items in `data` are dumped indo
+ memory goes over limit, the items in `data` are dumped into
disks, `data` will be cleared, all rest of items will be merged
into `pdata` and then dumped into disks. Before returning, all
the items in `pdata` will be dumped into disks.
@@ -193,16 +208,16 @@ class ExternalMerger(Merger):
>>> agg = SimpleAggregator(lambda x, y: x + y)
>>> merger = ExternalMerger(agg, 10)
>>> N = 10000
- >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10)
+ >>> merger.mergeValues(zip(xrange(N), xrange(N)))
>>> assert merger.spills > 0
>>> sum(v for k,v in merger.iteritems())
- 499950000
+ 49995000
>>> merger = ExternalMerger(agg, 10)
- >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10)
+ >>> merger.mergeCombiners(zip(xrange(N), xrange(N)))
>>> assert merger.spills > 0
>>> sum(v for k,v in merger.iteritems())
- 499950000
+ 49995000
"""
# the max total partitions created recursively
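
The updated doctest above exercises ExternalMerger with a summing aggregator. For reference, a similar sketch with a list-building aggregator, closer to what groupByKey needs (illustrative; it assumes pyspark is importable and relies on the interpreter's RSS already exceeding the 5 MB limit, so at least one spill happens):

    from pyspark.shuffle import Aggregator, ExternalMerger

    agg = Aggregator(lambda v: [v],            # createCombiner
                     lambda c, v: c + [v],     # mergeValue
                     lambda c1, c2: c1 + c2)   # mergeCombiners

    merger = ExternalMerger(agg, 5)            # ~5 MB soft limit
    merger.mergeValues((i % 100, i) for i in range(20000))
    assert merger.spills >= 1
    assert all(len(v) == 200 for _, v in merger.iteritems())
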
@@ -212,8 +227,7 @@ class ExternalMerger(Merger):
localdirs=None, scale=1, partitions=59, batch=1000):
Merger.__init__(self, aggregator)
self.memory_limit = memory_limit
- # default serializer is only used for tests
- self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
+ self.serializer = _compressed_serializer(serializer)
self.localdirs = localdirs or _get_local_dirs(str(id(self)))
# number of partitions when spill data into disks
self.partitions = partitions
@@ -221,7 +235,7 @@ class ExternalMerger(Merger):
self.batch = batch
# scale is used to scale down the hash of key for recursive hash map
self.scale = scale
- # unpartitioned merged data
+ # un-partitioned merged data
self.data = {}
# partitioned merged data, list of dicts
self.pdata = []
@@ -244,72 +258,63 @@ class ExternalMerger(Merger):
def mergeValues(self, iterator):
""" Combine the items by creator and combiner """
- iterator = iter(iterator)
# speedup attribute lookup
creator, comb = self.agg.createCombiner, self.agg.mergeValue
- d, c, batch = self.data, 0, self.batch
+ c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch
+ limit = self.memory_limit
for k, v in iterator:
+ d = pdata[hfun(k)] if pdata else data
d[k] = comb(d[k], v) if k in d else creator(v)
c += 1
- if c % batch == 0 and get_used_memory() > self.memory_limit:
- self._spill()
- self._partitioned_mergeValues(iterator, self._next_limit())
- break
+ if c >= batch:
+ if get_used_memory() >= limit:
+ self._spill()
+ limit = self._next_limit()
+ batch /= 2
+ c = 0
+ else:
+ batch *= 1.5
+
+ if get_used_memory() >= limit:
+ self._spill()
def _partition(self, key):
""" Return the partition for key """
return hash((key, self._seed)) % self.partitions
- def _partitioned_mergeValues(self, iterator, limit=0):
- """ Partition the items by key, then combine them """
- # speedup attribute lookup
- creator, comb = self.agg.createCombiner, self.agg.mergeValue
- c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch
-
- for k, v in iterator:
- d = pdata[hfun(k)]
- d[k] = comb(d[k], v) if k in d else creator(v)
- if not limit:
- continue
-
- c += 1
- if c % batch == 0 and get_used_memory() > limit:
- self._spill()
- limit = self._next_limit()
+ def _object_size(self, obj):
+ """ How much of memory for this obj, assume that all the objects
+ consume similar bytes of memory
+ """
+ return 1
- def mergeCombiners(self, iterator, check=True):
+ def mergeCombiners(self, iterator, limit=None):
""" Merge (K,V) pair by mergeCombiner """
- iterator = iter(iterator)
+ if limit is None:
+ limit = self.memory_limit
# speedup attribute lookup
- d, comb, batch = self.data, self.agg.mergeCombiners, self.batch
- c = 0
- for k, v in iterator:
- d[k] = comb(d[k], v) if k in d else v
- if not check:
- continue
-
- c += 1
- if c % batch == 0 and get_used_memory() > self.memory_limit:
- self._spill()
- self._partitioned_mergeCombiners(iterator, self._next_limit())
- break
-
- def _partitioned_mergeCombiners(self, iterator, limit=0):
- """ Partition the items by key, then merge them """
- comb, pdata = self.agg.mergeCombiners, self.pdata
- c, hfun = 0, self._partition
+ comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size
+ c, data, pdata, batch = 0, self.data, self.pdata, self.batch
for k, v in iterator:
- d = pdata[hfun(k)]
+ d = pdata[hfun(k)] if pdata else data
d[k] = comb(d[k], v) if k in d else v
if not limit:
continue
- c += 1
- if c % self.batch == 0 and get_used_memory() > limit:
- self._spill()
- limit = self._next_limit()
+ c += objsize(v)
+ if c > batch:
+ if get_used_memory() > limit:
+ self._spill()
+ limit = self._next_limit()
+ batch /= 2
+ c = 0
+ else:
+ batch *= 1.5
+
+ if limit and get_used_memory() >= limit:
+ self._spill()
def _spill(self):
"""
@@ -335,7 +340,7 @@ class ExternalMerger(Merger):
for k, v in self.data.iteritems():
h = self._partition(k)
- # put one item in batch, make it compatitable with load_stream
+ # put one item in batch, make it compatible with load_stream
# it will increase the memory if dump them in batch
self.serializer.dump_stream([(k, v)], streams[h])
@@ -344,7 +349,7 @@ class ExternalMerger(Merger):
s.close()
self.data.clear()
- self.pdata = [{} for i in range(self.partitions)]
+ self.pdata.extend([{} for i in range(self.partitions)])
else:
for i in range(self.partitions):
@@ -370,29 +375,12 @@ class ExternalMerger(Merger):
assert not self.data
if any(self.pdata):
self._spill()
- hard_limit = self._next_limit()
+ # disable partitioning and spilling when merging combiners from disk
+ self.pdata = []
try:
for i in range(self.partitions):
- self.data = {}
- for j in range(self.spills):
- path = self._get_spill_dir(j)
- p = os.path.join(path, str(i))
- # do not check memory during merging
- self.mergeCombiners(self.serializer.load_stream(open(p)),
- False)
-
- # limit the total partitions
- if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
- and j < self.spills - 1
- and get_used_memory() > hard_limit):
- self.data.clear() # will read from disk again
- gc.collect() # release the memory as much as possible
- for v in self._recursive_merged_items(i):
- yield v
- return
-
- for v in self.data.iteritems():
+ for v in self._merged_items(i):
yield v
self.data.clear()
@@ -400,53 +388,56 @@ class ExternalMerger(Merger):
for j in range(self.spills):
path = self._get_spill_dir(j)
os.remove(os.path.join(path, str(i)))
-
finally:
self._cleanup()
- def _cleanup(self):
- """ Clean up all the files in disks """
- for d in self.localdirs:
- shutil.rmtree(d, True)
+ def _merged_items(self, index):
+ self.data = {}
+ limit = self._next_limit()
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(index))
+ # do not check memory during merging
+ self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+
+ # limit the total partitions
+ if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
+ and j < self.spills - 1
+ and get_used_memory() > limit):
+ self.data.clear() # will read from disk again
+ gc.collect() # release the memory as much as possible
+ return self._recursive_merged_items(index)
- def _recursive_merged_items(self, start):
+ return self.data.iteritems()
+
+ def _recursive_merged_items(self, index):
"""
merge the partitioned items and return the as iterator
If one partition can not be fit in memory, then them will be
partitioned and merged recursively.
"""
- # make sure all the data are dumps into disks.
- assert not self.data
- if any(self.pdata):
- self._spill()
- assert self.spills > 0
-
- for i in range(start, self.partitions):
- subdirs = [os.path.join(d, "parts", str(i))
- for d in self.localdirs]
- m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
- subdirs, self.scale * self.partitions, self.partitions)
- m.pdata = [{} for _ in range(self.partitions)]
- limit = self._next_limit()
-
- for j in range(self.spills):
- path = self._get_spill_dir(j)
- p = os.path.join(path, str(i))
- m._partitioned_mergeCombiners(
- self.serializer.load_stream(open(p)))
-
- if get_used_memory() > limit:
- m._spill()
- limit = self._next_limit()
+ subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs]
+ m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs,
+ self.scale * self.partitions, self.partitions, self.batch)
+ m.pdata = [{} for _ in range(self.partitions)]
+ limit = self._next_limit()
+
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(index))
+ m.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+
+ if get_used_memory() > limit:
+ m._spill()
+ limit = self._next_limit()
- for v in m._external_items():
- yield v
+ return m._external_items()
- # remove the merged partition
- for j in range(self.spills):
- path = self._get_spill_dir(j)
- os.remove(os.path.join(path, str(i)))
+ def _cleanup(self):
+ """ Clean up all the files in disks """
+ for d in self.localdirs:
+ shutil.rmtree(d, True)
class ExternalSorter(object):
@@ -457,6 +448,7 @@ class ExternalSorter(object):
The spilling will only happen when the used memory goes above
the limit.
+
>>> sorter = ExternalSorter(1) # 1M
>>> import random
>>> l = range(1024)
@@ -469,7 +461,7 @@ class ExternalSorter(object):
def __init__(self, memory_limit, serializer=None):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
- self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
+ self.serializer = _compressed_serializer(serializer)
def _get_path(self, n):
""" Choose one directory for spill by number n """
@@ -515,6 +507,7 @@ class ExternalSorter(object):
limit = self._next_limit()
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
DiskBytesSpilled += os.path.getsize(path)
+ os.unlink(path) # data will be deleted after close
elif not chunks:
batch = min(batch * 2, 10000)
@@ -529,6 +522,310 @@ class ExternalSorter(object):
return heapq.merge(chunks, key=key, reverse=reverse)
+class ExternalList(object):
+ """
+ ExternalList can hold many items which cannot all fit in memory
+ at the same time.
+
+ >>> l = ExternalList(range(100))
+ >>> len(l)
+ 100
+ >>> l.append(10)
+ >>> len(l)
+ 101
+ >>> for i in range(20240):
+ ... l.append(i)
+ >>> len(l)
+ 20341
+ >>> import pickle
+ >>> l2 = pickle.loads(pickle.dumps(l))
+ >>> len(l2)
+ 20341
+ >>> list(l2)[100]
+ 10
+ """
+ LIMIT = 10240
+
+ def __init__(self, values):
+ self.values = values
+ self.count = len(values)
+ self._file = None
+ self._ser = None
+
+ def __getstate__(self):
+ if self._file is not None:
+ self._file.flush()
+ f = os.fdopen(os.dup(self._file.fileno()))
+ f.seek(0)
+ serialized = f.read()
+ else:
+ serialized = ''
+ return self.values, self.count, serialized
+
+ def __setstate__(self, item):
+ self.values, self.count, serialized = item
+ if serialized:
+ self._open_file()
+ self._file.write(serialized)
+ else:
+ self._file = None
+ self._ser = None
+
+ def __iter__(self):
+ if self._file is not None:
+ self._file.flush()
+ # read all items from disks first
+ with os.fdopen(os.dup(self._file.fileno()), 'r') as f:
+ f.seek(0)
+ for v in self._ser.load_stream(f):
+ yield v
+
+ for v in self.values:
+ yield v
+
+ def __len__(self):
+ return self.count
+
+ def append(self, value):
+ self.values.append(value)
+ self.count += 1
+ # dump the values into disk when there are too many of them
+ if len(self.values) >= self.LIMIT:
+ self._spill()
+
+ def _open_file(self):
+ dirs = _get_local_dirs("objects")
+ d = dirs[id(self) % len(dirs)]
+ if not os.path.exists(d):
+ os.makedirs(d)
+ p = os.path.join(d, str(id))
+ self._file = open(p, "w+", 65536)
+ self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
+ os.unlink(p)
+
+ def _spill(self):
+ """ dump the values into disk """
+ global MemoryBytesSpilled, DiskBytesSpilled
+ if self._file is None:
+ self._open_file()
+
+ used_memory = get_used_memory()
+ pos = self._file.tell()
+ self._ser.dump_stream(self.values, self._file)
+ self.values = []
+ gc.collect()
+ DiskBytesSpilled += self._file.tell() - pos
+ MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+
+
+class ExternalListOfList(ExternalList):
+ """
+ An external list for list.
+
+ >>> l = ExternalListOfList([[i, i] for i in range(100)])
+ >>> len(l)
+ 200
+ >>> l.append(range(10))
+ >>> len(l)
+ 210
+ >>> len(list(l))
+ 210
+ """
+
+ def __init__(self, values):
+ ExternalList.__init__(self, values)
+ self.count = sum(len(i) for i in values)
+
+ def append(self, value):
+ ExternalList.append(self, value)
+ # already counted 1 in ExternalList.append
+ self.count += len(value) - 1
+
+ def __iter__(self):
+ for values in ExternalList.__iter__(self):
+ for v in values:
+ yield v
+
+
+class GroupByKey(object):
+ """
+ Group a sorted iterator as [(k1, it1), (k2, it2), ...]
+
+ >>> k = [i/3 for i in range(6)]
+ >>> v = [[i] for i in range(6)]
+ >>> g = GroupByKey(iter(zip(k, v)))
+ >>> [(k, list(it)) for k, it in g]
+ [(0, [0, 1, 2]), (1, [3, 4, 5])]
+ """
+
+ def __init__(self, iterator):
+ self.iterator = iter(iterator)
+ self.next_item = None
+
+ def __iter__(self):
+ return self
+
+ def next(self):
+ key, value = self.next_item if self.next_item else next(self.iterator)
+ values = ExternalListOfList([value])
+ try:
+ while True:
+ k, v = next(self.iterator)
+ if k != key:
+ self.next_item = (k, v)
+ break
+ values.append(v)
+ except StopIteration:
+ self.next_item = None
+ return key, values
+
+
+class ExternalGroupBy(ExternalMerger):
+
+ """
+ Group the items by key. If any partition of them cannot be held
+ in memory, it will do a sort-based group-by.
+
+ This class works as follows:
+
+ - It repeatedly groups the items by key and saves them in one dict
+ in memory.
+
+ - When the used memory goes above the memory limit, it splits
+ the combined data into partitions by hash code and dumps them
+ into disk, one file per partition. If the number of keys
+ in one partition is smaller than 1000, it sorts them
+ by key before dumping into disk.
+
+ - Then it goes through the rest of the iterator, grouping items
+ by key into different dicts by hash. Whenever the used memory
+ goes over the memory limit, it dumps all the dicts into disk,
+ one file per dict, and repeats this until all the items are
+ combined. It will also try to sort the items by key in each
+ partition before dumping into disk.
+
+ - It yields the grouped items partition by partition.
+ If the data in one partition can be held in memory, it
+ loads and combines them in memory and yields the result.
+
+ - If the dataset in one partition cannot be held in memory,
+ it sorts the items first. If all the spilled files are already
+ sorted, it merges them with heapq.merge(); otherwise it does an
+ external sort over all the files first.
+
+ - After sorting, the `GroupByKey` class gathers all the consecutive
+ items with the same key into a group and yields the values as
+ an iterator.
+ """
+ SORT_KEY_LIMIT = 1000
+
+ def flattened_serializer(self):
+ assert isinstance(self.serializer, BatchedSerializer)
+ ser = self.serializer
+ return FlattenedValuesSerializer(ser, 20)
+
+ def _object_size(self, obj):
+ return len(obj)
+
+ def _spill(self):
+ """
+ dump already partitioned data into disks.
+ """
+ global MemoryBytesSpilled, DiskBytesSpilled
+ path = self._get_spill_dir(self.spills)
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ used_memory = get_used_memory()
+ if not self.pdata:
+ # The data has not been partitioned yet; iterate over it
+ # once and write the items into different files, using no
+ # additional memory. This is only called the first time the
+ # memory goes above the limit.
+
+ # open all the files for writing
+ streams = [open(os.path.join(path, str(i)), 'w')
+ for i in range(self.partitions)]
+
+ # If the number of keys is small, the overhead of sorting is small;
+ # sort them before dumping into disk
+ self._sorted = len(self.data) < self.SORT_KEY_LIMIT
+ if self._sorted:
+ self.serializer = self.flattened_serializer()
+ for k in sorted(self.data.keys()):
+ h = self._partition(k)
+ self.serializer.dump_stream([(k, self.data[k])], streams[h])
+ else:
+ for k, v in self.data.iteritems():
+ h = self._partition(k)
+ self.serializer.dump_stream([(k, v)], streams[h])
+
+ for s in streams:
+ DiskBytesSpilled += s.tell()
+ s.close()
+
+ self.data.clear()
+ # self.pdata is cached in `mergeValues` and `mergeCombiners`
+ self.pdata.extend([{} for i in range(self.partitions)])
+
+ else:
+ for i in range(self.partitions):
+ p = os.path.join(path, str(i))
+ with open(p, "w") as f:
+ # dump items in batch
+ if self._sorted:
+ # sort by key only (stable)
+ sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0))
+ self.serializer.dump_stream(sorted_items, f)
+ else:
+ self.serializer.dump_stream(self.pdata[i].iteritems(), f)
+ self.pdata[i].clear()
+ DiskBytesSpilled += os.path.getsize(p)
+
+ self.spills += 1
+ gc.collect() # release the memory as much as possible
+ MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+
+ def _merged_items(self, index):
+ size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
+ for j in range(self.spills))
+ # if memory cannot hold the whole partition, use sort-based
+ # merge. Because of compression, the data on disk will be
+ # much smaller than the memory needed
+ if (size >> 20) >= self.memory_limit / 10:
+ return self._merge_sorted_items(index)
+
+ self.data = {}
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(index))
+ # do not check memory during merging
+ self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+ return self.data.iteritems()
+
+ def _merge_sorted_items(self, index):
+ """ load a partition from disk, then sort and group by key """
+ def load_partition(j):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(index))
+ return self.serializer.load_stream(open(p, 'r', 65536))
+
+ disk_items = [load_partition(j) for j in range(self.spills)]
+
+ if self._sorted:
+ # all the partitions are already sorted
+ sorted_items = heapq.merge(disk_items, key=operator.itemgetter(0))
+
+ else:
+ # Flatten the combined values so they do not consume huge
+ # amounts of memory during the merge sort.
+ ser = self.flattened_serializer()
+ sorter = ExternalSorter(self.memory_limit, ser)
+ sorted_items = sorter.sorted(itertools.chain(*disk_items),
+ key=operator.itemgetter(0))
+ return ((k, vs) for k, vs in GroupByKey(sorted_items))
+
+
if __name__ == "__main__":
import doctest
doctest.testmod()
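
When a spilled partition is too large to rebuild in memory, ExternalGroupBy sorts the spill files by key, merges the sorted streams, and lets GroupByKey turn runs of equal keys into single groups. A standalone sketch of that idea using only the standard library (heapq.merge and itertools.groupby stand in for pyspark.heapq3 and the GroupByKey class above):

    import heapq
    import itertools
    import operator

    # Two spill files, each already sorted by key, holding (key, values) records.
    spill_1 = [("a", [1]), ("b", [2])]
    spill_2 = [("a", [3]), ("c", [4])]

    merged = heapq.merge(spill_1, spill_2)     # one stream, still sorted by key
    grouped = [(k, [v for _, vs in run for v in vs])
               for k, run in itertools.groupby(merged, key=operator.itemgetter(0))]
    print(grouped)   # [('a', [1, 3]), ('b', [2]), ('c', [4])]
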
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index dd8d3b1c53..0bd5d20f78 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -31,6 +31,7 @@ import tempfile
import time
import zipfile
import random
+import itertools
import threading
import hashlib
@@ -76,7 +77,7 @@ SPARK_HOME = os.environ["SPARK_HOME"]
class MergerTests(unittest.TestCase):
def setUp(self):
- self.N = 1 << 14
+ self.N = 1 << 12
self.l = [i for i in xrange(self.N)]
self.data = zip(self.l, self.l)
self.agg = Aggregator(lambda x: [x],
@@ -108,7 +109,7 @@ class MergerTests(unittest.TestCase):
sum(xrange(self.N)))
def test_medium_dataset(self):
- m = ExternalMerger(self.agg, 10)
+ m = ExternalMerger(self.agg, 30)
m.mergeValues(self.data)
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
@@ -124,10 +125,36 @@ class MergerTests(unittest.TestCase):
m = ExternalMerger(self.agg, 10, partitions=3)
m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
self.assertTrue(m.spills >= 1)
- self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
+ self.assertEqual(sum(len(v) for k, v in m.iteritems()),
self.N * 10)
m._cleanup()
+ def test_group_by_key(self):
+
+ def gen_data(N, step):
+ for i in range(1, N + 1, step):
+ for j in range(i):
+ yield (i, [j])
+
+ def gen_gs(N, step=1):
+ return shuffle.GroupByKey(gen_data(N, step))
+
+ self.assertEqual(1, len(list(gen_gs(1))))
+ self.assertEqual(2, len(list(gen_gs(2))))
+ self.assertEqual(100, len(list(gen_gs(100))))
+ self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)])
+ self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100)))
+
+ for k, vs in gen_gs(50002, 10000):
+ self.assertEqual(k, len(vs))
+ self.assertEqual(range(k), list(vs))
+
+ ser = PickleSerializer()
+ l = ser.loads(ser.dumps(list(gen_gs(50002, 30000))))
+ for k, vs in l:
+ self.assertEqual(k, len(vs))
+ self.assertEqual(range(k), list(vs))
+
class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
@@ -702,6 +729,21 @@ class RDDTests(ReusedPySparkTestCase):
self.assertEquals(result.getNumPartitions(), 5)
self.assertEquals(result.count(), 3)
+ def test_external_group_by_key(self):
+ self.sc._conf.set("spark.python.worker.memory", "5m")
+ N = 200001
+ kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x))
+ gkv = kv.groupByKey().cache()
+ self.assertEqual(3, gkv.count())
+ filtered = gkv.filter(lambda (k, vs): k == 1)
+ self.assertEqual(1, filtered.count())
+ self.assertEqual([(1, N/3)], filtered.mapValues(len).collect())
+ self.assertEqual([(N/3, N/3)],
+ filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
+ result = filtered.collect()[0][1]
+ self.assertEqual(N/3, len(result))
+ self.assertTrue(isinstance(result.data, shuffle.ExternalList))
+
def test_sort_on_empty_rdd(self):
self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
@@ -752,9 +794,9 @@ class RDDTests(ReusedPySparkTestCase):
self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions())
self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions())
- self.sc.setJobGroup("test1", "test", True)
tracker = self.sc.statusTracker()
+ self.sc.setJobGroup("test1", "test", True)
d = sorted(parted.join(parted).collect())
self.assertEqual(10, len(d))
self.assertEqual((0, (0, 0)), d[0])
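
The new test_external_group_by_key above forces the external code path by shrinking spark.python.worker.memory to 5m before grouping ~200k values under three keys. A hedged sketch of reproducing that in a local session (master and app name here are arbitrary choices):

    from pyspark import SparkConf, SparkContext

    conf = SparkConf().set("spark.python.worker.memory", "5m")
    sc = SparkContext("local[2]", "external-groupby-demo", conf=conf)

    N = 200001
    kv = sc.parallelize(range(N)).map(lambda x: (x % 3, x))
    # mapValues(len) reads each group's length without collecting the values.
    print(sorted(kv.groupByKey().mapValues(len).collect()))
    # [(0, 66667), (1, 66667), (2, 66667)]
    sc.stop()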