aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies.liu@gmail.com>2015-04-09 17:07:23 -0700
committerJosh Rosen <joshrosen@databricks.com>2015-04-09 17:07:23 -0700
commitb5c51c8df480f1a82a82e4d597d8eea631bffb4e (patch)
tree7842078dd2b5d8dab92a129725353647136e8a9f /python
parent9c67049b4ef416a80803ccb958bbac1dd02cc380 (diff)
downloadspark-b5c51c8df480f1a82a82e4d597d8eea631bffb4e.tar.gz
spark-b5c51c8df480f1a82a82e4d597d8eea631bffb4e.tar.bz2
spark-b5c51c8df480f1a82a82e4d597d8eea631bffb4e.zip
[SPARK-3074] [PySpark] support groupByKey() with single huge key
This patch change groupByKey() to use external sort based approach, so it can support single huge key. For example, it can group by a dataset including one hot key with 40 millions values (strings), using 500M memory for Python worker, finished in about 2 minutes. (it will need 6G memory in hash based approach). During groupByKey(), it will do in-memory groupBy first. If the dataset can not fit in memory, then data will be partitioned by hash. If one partition still can not fit in memory, it will switch to sort based groupBy(). Author: Davies Liu <davies.liu@gmail.com> Author: Davies Liu <davies@databricks.com> Closes #1977 from davies/groupby and squashes the following commits: af3713a [Davies Liu] make sure it's iterator 67772dd [Davies Liu] fix tests e78c15c [Davies Liu] address comments 0b0fde8 [Davies Liu] address comments 0dcf320 [Davies Liu] address comments, rollback changes in ResultIterable e3b8eab [Davies Liu] fix narrow dependency 2a1857a [Davies Liu] typo d2f053b [Davies Liu] add repr for FlattedValuesSerializer c6a2f8d [Davies Liu] address comments 9e2df24 [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby 2b9c261 [Davies Liu] fix typo in comments 70aadcd [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby a14b4bd [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby ab5515b [Davies Liu] Merge branch 'master' into groupby 651f891 [Davies Liu] simplify GroupByKey 1578f2e [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby 1f69f93 [Davies Liu] fix tests 0d3395f [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby 341f1e0 [Davies Liu] add comments, refactor 47918b8 [Davies Liu] remove unused code 6540948 [Davies Liu] address comments: 17f4ec6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby 4d4bc86 [Davies Liu] bugfix 8ef965e [Davies Liu] Merge branch 'master' into groupby fbc504a [Davies Liu] Merge branch 'master' into groupby 779ed03 [Davies Liu] fix merge conflict 2c1d05b [Davies Liu] refactor, minor turning b48cda5 [Davies Liu] Merge branch 'master' into groupby 85138e6 [Davies Liu] Merge branch 'master' into groupby acd8e1b [Davies Liu] fix memory when groupByKey().count() 905b233 [Davies Liu] Merge branch 'sort' into groupby 1f075ed [Davies Liu] Merge branch 'master' into sort 4b07d39 [Davies Liu] compress the data while spilling 0a081c6 [Davies Liu] Merge branch 'master' into groupby f157fe7 [Davies Liu] Merge branch 'sort' into groupby eb53ca6 [Davies Liu] Merge branch 'master' into sort b2dc3bf [Davies Liu] Merge branch 'sort' into groupby 644abaf [Davies Liu] add license in LICENSE 19f7873 [Davies Liu] improve tests 11ba318 [Davies Liu] typo 085aef8 [Davies Liu] Merge branch 'master' into groupby 3ee58e5 [Davies Liu] switch to sort based groupBy, based on size of data 1ea0669 [Davies Liu] choose sort based groupByKey() automatically b40bae7 [Davies Liu] bugfix efa23df [Davies Liu] refactor, add spark.shuffle.sort=False 250be4e [Davies Liu] flatten the combined values when dumping into disks d05060d [Davies Liu] group the same key before shuffle, reduce the comparison during sorting 083d842 [Davies Liu] sorted based groupByKey() 55602ee [Davies Liu] use external sort in sortBy() and sortByKey()
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/join.py13
-rw-r--r--python/pyspark/rdd.py48
-rw-r--r--python/pyspark/resultiterable.py7
-rw-r--r--python/pyspark/serializers.py25
-rw-r--r--python/pyspark/shuffle.py531
-rw-r--r--python/pyspark/tests.py50
6 files changed, 531 insertions, 143 deletions
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index efc1ef9396..c3491defb2 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -48,7 +48,7 @@ def python_join(rdd, other, numPartitions):
vbuf.append(v)
elif n == 2:
wbuf.append(v)
- return [(v, w) for v in vbuf for w in wbuf]
+ return ((v, w) for v in vbuf for w in wbuf)
return _do_python_join(rdd, other, numPartitions, dispatch)
@@ -62,7 +62,7 @@ def python_right_outer_join(rdd, other, numPartitions):
wbuf.append(v)
if not vbuf:
vbuf.append(None)
- return [(v, w) for v in vbuf for w in wbuf]
+ return ((v, w) for v in vbuf for w in wbuf)
return _do_python_join(rdd, other, numPartitions, dispatch)
@@ -76,7 +76,7 @@ def python_left_outer_join(rdd, other, numPartitions):
wbuf.append(v)
if not wbuf:
wbuf.append(None)
- return [(v, w) for v in vbuf for w in wbuf]
+ return ((v, w) for v in vbuf for w in wbuf)
return _do_python_join(rdd, other, numPartitions, dispatch)
@@ -104,8 +104,9 @@ def python_cogroup(rdds, numPartitions):
rdd_len = len(vrdds)
def dispatch(seq):
- bufs = [[] for i in range(rdd_len)]
- for (n, v) in seq:
+ bufs = [[] for _ in range(rdd_len)]
+ for n, v in seq:
bufs[n].append(v)
- return tuple(map(ResultIterable, bufs))
+ return tuple(ResultIterable(vs) for vs in bufs)
+
return union_vrdds.groupByKey(numPartitions).mapValues(dispatch)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 2d05611321..1b18789040 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -41,7 +41,7 @@ from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
- get_used_memory, ExternalSorter
+ get_used_memory, ExternalSorter, ExternalGroupBy
from pyspark.traceback_utils import SCCallSiteSync
from py4j.java_collections import ListConverter, MapConverter
@@ -573,8 +573,8 @@ class RDD(object):
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
- spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
- memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+ spill = self._can_spill()
+ memory = self._memory_limit()
serializer = self._jrdd_deserializer
def sortPartition(iterator):
@@ -1699,10 +1699,8 @@ class RDD(object):
numPartitions = self._defaultReducePartitions()
serializer = self.ctx.serializer
- spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
- == 'true')
- memory = _parse_memory(self.ctx._conf.get(
- "spark.python.worker.memory", "512m"))
+ spill = self._can_spill()
+ memory = self._memory_limit()
agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
def combineLocally(iterator):
@@ -1755,21 +1753,28 @@ class RDD(object):
return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)
+ def _can_spill(self):
+ return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"
+
+ def _memory_limit(self):
+ return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+
# TODO: support variant with custom partitioner
def groupByKey(self, numPartitions=None):
"""
Group the values for each key in the RDD into a single sequence.
- Hash-partitions the resulting RDD with into numPartitions partitions.
+ Hash-partitions the resulting RDD with numPartitions partitions.
Note: If you are grouping in order to perform an aggregation (such as a
sum or average) over each key, using reduceByKey or aggregateByKey will
provide much better performance.
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
- >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
+ >>> sorted(x.groupByKey().mapValues(len).collect())
+ [('a', 2), ('b', 1)]
+ >>> sorted(x.groupByKey().mapValues(list).collect())
[('a', [1, 1]), ('b', [1])]
"""
-
def createCombiner(x):
return [x]
@@ -1781,8 +1786,27 @@ class RDD(object):
a.extend(b)
return a
- return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
- numPartitions).mapValues(lambda x: ResultIterable(x))
+ spill = self._can_spill()
+ memory = self._memory_limit()
+ serializer = self._jrdd_deserializer
+ agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
+
+ def combine(iterator):
+ merger = ExternalMerger(agg, memory * 0.9, serializer) \
+ if spill else InMemoryMerger(agg)
+ merger.mergeValues(iterator)
+ return merger.iteritems()
+
+ locally_combined = self.mapPartitions(combine, preservesPartitioning=True)
+ shuffled = locally_combined.partitionBy(numPartitions)
+
+ def groupByKey(it):
+ merger = ExternalGroupBy(agg, memory, serializer)\
+ if spill else InMemoryMerger(agg)
+ merger.mergeCombiners(it)
+ return merger.iteritems()
+
+ return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable)
def flatMapValues(self, f):
"""
diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py
index ef04c82866..1ab5ce14c3 100644
--- a/python/pyspark/resultiterable.py
+++ b/python/pyspark/resultiterable.py
@@ -15,15 +15,16 @@
# limitations under the License.
#
-__all__ = ["ResultIterable"]
-
import collections
+__all__ = ["ResultIterable"]
+
class ResultIterable(collections.Iterable):
"""
- A special result iterable. This is used because the standard iterator can not be pickled
+ A special result iterable. This is used because the standard
+ iterator can not be pickled
"""
def __init__(self, data):
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 0ffb41d02f..4afa82f4b2 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -220,6 +220,29 @@ class BatchedSerializer(Serializer):
return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)
+class FlattenedValuesSerializer(BatchedSerializer):
+
+ """
+ Serializes a stream of list of pairs, split the list of values
+ which contain more than a certain number of objects to make them
+ have similar sizes.
+ """
+ def __init__(self, serializer, batchSize=10):
+ BatchedSerializer.__init__(self, serializer, batchSize)
+
+ def _batched(self, iterator):
+ n = self.batchSize
+ for key, values in iterator:
+ for i in xrange(0, len(values), n):
+ yield key, values[i:i + n]
+
+ def load_stream(self, stream):
+ return self.serializer.load_stream(stream)
+
+ def __repr__(self):
+ return "FlattenedValuesSerializer(%d)" % self.batchSize
+
+
class AutoBatchedSerializer(BatchedSerializer):
"""
Choose the size of batch automatically based on the size of object
@@ -251,7 +274,7 @@ class AutoBatchedSerializer(BatchedSerializer):
return (isinstance(other, AutoBatchedSerializer) and
other.serializer == self.serializer and other.bestSize == self.bestSize)
- def __str__(self):
+ def __repr__(self):
return "AutoBatchedSerializer(%s)" % str(self.serializer)
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 10a7ccd502..8a6fc627eb 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -16,28 +16,35 @@
#
import os
-import sys
import platform
import shutil
import warnings
import gc
import itertools
+import operator
import random
import pyspark.heapq3 as heapq
-from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
+from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
+ CompressedSerializer, AutoBatchedSerializer
+
try:
import psutil
+ process = None
+
def get_used_memory():
""" Return the used memory in MB """
- process = psutil.Process(os.getpid())
+ global process
+ if process is None or process._pid != os.getpid():
+ process = psutil.Process(os.getpid())
if hasattr(process, "memory_info"):
info = process.memory_info()
else:
info = process.get_memory_info()
return info.rss >> 20
+
except ImportError:
def get_used_memory():
@@ -46,6 +53,7 @@ except ImportError:
for line in open('/proc/self/status'):
if line.startswith('VmRSS:'):
return int(line.split()[1]) >> 10
+
else:
warnings.warn("Please install psutil to have better "
"support with spilling")
@@ -54,6 +62,7 @@ except ImportError:
rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
return rss >> 20
# TODO: support windows
+
return 0
@@ -148,10 +157,16 @@ class InMemoryMerger(Merger):
d[k] = comb(d[k], v) if k in d else v
def iteritems(self):
- """ Return the merged items ad iterator """
+ """ Return the merged items as iterator """
return self.data.iteritems()
+def _compressed_serializer(self, serializer=None):
+ # always use PickleSerializer to simplify implementation
+ ser = PickleSerializer()
+ return AutoBatchedSerializer(CompressedSerializer(ser))
+
+
class ExternalMerger(Merger):
"""
@@ -173,7 +188,7 @@ class ExternalMerger(Merger):
dict. Repeat this again until combine all the items.
- Before return any items, it will load each partition and
- combine them seperately. Yield them before loading next
+ combine them separately. Yield them before loading next
partition.
- During loading a partition, if the memory goes over limit,
@@ -182,7 +197,7 @@ class ExternalMerger(Merger):
`data` and `pdata` are used to hold the merged items in memory.
At first, all the data are merged into `data`. Once the used
- memory goes over limit, the items in `data` are dumped indo
+ memory goes over limit, the items in `data` are dumped into
disks, `data` will be cleared, all rest of items will be merged
into `pdata` and then dumped into disks. Before returning, all
the items in `pdata` will be dumped into disks.
@@ -193,16 +208,16 @@ class ExternalMerger(Merger):
>>> agg = SimpleAggregator(lambda x, y: x + y)
>>> merger = ExternalMerger(agg, 10)
>>> N = 10000
- >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10)
+ >>> merger.mergeValues(zip(xrange(N), xrange(N)))
>>> assert merger.spills > 0
>>> sum(v for k,v in merger.iteritems())
- 499950000
+ 49995000
>>> merger = ExternalMerger(agg, 10)
- >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10)
+ >>> merger.mergeCombiners(zip(xrange(N), xrange(N)))
>>> assert merger.spills > 0
>>> sum(v for k,v in merger.iteritems())
- 499950000
+ 49995000
"""
# the max total partitions created recursively
@@ -212,8 +227,7 @@ class ExternalMerger(Merger):
localdirs=None, scale=1, partitions=59, batch=1000):
Merger.__init__(self, aggregator)
self.memory_limit = memory_limit
- # default serializer is only used for tests
- self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
+ self.serializer = _compressed_serializer(serializer)
self.localdirs = localdirs or _get_local_dirs(str(id(self)))
# number of partitions when spill data into disks
self.partitions = partitions
@@ -221,7 +235,7 @@ class ExternalMerger(Merger):
self.batch = batch
# scale is used to scale down the hash of key for recursive hash map
self.scale = scale
- # unpartitioned merged data
+ # un-partitioned merged data
self.data = {}
# partitioned merged data, list of dicts
self.pdata = []
@@ -244,72 +258,63 @@ class ExternalMerger(Merger):
def mergeValues(self, iterator):
""" Combine the items by creator and combiner """
- iterator = iter(iterator)
# speedup attribute lookup
creator, comb = self.agg.createCombiner, self.agg.mergeValue
- d, c, batch = self.data, 0, self.batch
+ c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch
+ limit = self.memory_limit
for k, v in iterator:
+ d = pdata[hfun(k)] if pdata else data
d[k] = comb(d[k], v) if k in d else creator(v)
c += 1
- if c % batch == 0 and get_used_memory() > self.memory_limit:
- self._spill()
- self._partitioned_mergeValues(iterator, self._next_limit())
- break
+ if c >= batch:
+ if get_used_memory() >= limit:
+ self._spill()
+ limit = self._next_limit()
+ batch /= 2
+ c = 0
+ else:
+ batch *= 1.5
+
+ if get_used_memory() >= limit:
+ self._spill()
def _partition(self, key):
""" Return the partition for key """
return hash((key, self._seed)) % self.partitions
- def _partitioned_mergeValues(self, iterator, limit=0):
- """ Partition the items by key, then combine them """
- # speedup attribute lookup
- creator, comb = self.agg.createCombiner, self.agg.mergeValue
- c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch
-
- for k, v in iterator:
- d = pdata[hfun(k)]
- d[k] = comb(d[k], v) if k in d else creator(v)
- if not limit:
- continue
-
- c += 1
- if c % batch == 0 and get_used_memory() > limit:
- self._spill()
- limit = self._next_limit()
+ def _object_size(self, obj):
+ """ How much of memory for this obj, assume that all the objects
+ consume similar bytes of memory
+ """
+ return 1
- def mergeCombiners(self, iterator, check=True):
+ def mergeCombiners(self, iterator, limit=None):
""" Merge (K,V) pair by mergeCombiner """
- iterator = iter(iterator)
+ if limit is None:
+ limit = self.memory_limit
# speedup attribute lookup
- d, comb, batch = self.data, self.agg.mergeCombiners, self.batch
- c = 0
- for k, v in iterator:
- d[k] = comb(d[k], v) if k in d else v
- if not check:
- continue
-
- c += 1
- if c % batch == 0 and get_used_memory() > self.memory_limit:
- self._spill()
- self._partitioned_mergeCombiners(iterator, self._next_limit())
- break
-
- def _partitioned_mergeCombiners(self, iterator, limit=0):
- """ Partition the items by key, then merge them """
- comb, pdata = self.agg.mergeCombiners, self.pdata
- c, hfun = 0, self._partition
+ comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size
+ c, data, pdata, batch = 0, self.data, self.pdata, self.batch
for k, v in iterator:
- d = pdata[hfun(k)]
+ d = pdata[hfun(k)] if pdata else data
d[k] = comb(d[k], v) if k in d else v
if not limit:
continue
- c += 1
- if c % self.batch == 0 and get_used_memory() > limit:
- self._spill()
- limit = self._next_limit()
+ c += objsize(v)
+ if c > batch:
+ if get_used_memory() > limit:
+ self._spill()
+ limit = self._next_limit()
+ batch /= 2
+ c = 0
+ else:
+ batch *= 1.5
+
+ if limit and get_used_memory() >= limit:
+ self._spill()
def _spill(self):
"""
@@ -335,7 +340,7 @@ class ExternalMerger(Merger):
for k, v in self.data.iteritems():
h = self._partition(k)
- # put one item in batch, make it compatitable with load_stream
+ # put one item in batch, make it compatible with load_stream
# it will increase the memory if dump them in batch
self.serializer.dump_stream([(k, v)], streams[h])
@@ -344,7 +349,7 @@ class ExternalMerger(Merger):
s.close()
self.data.clear()
- self.pdata = [{} for i in range(self.partitions)]
+ self.pdata.extend([{} for i in range(self.partitions)])
else:
for i in range(self.partitions):
@@ -370,29 +375,12 @@ class ExternalMerger(Merger):
assert not self.data
if any(self.pdata):
self._spill()
- hard_limit = self._next_limit()
+ # disable partitioning and spilling when merge combiners from disk
+ self.pdata = []
try:
for i in range(self.partitions):
- self.data = {}
- for j in range(self.spills):
- path = self._get_spill_dir(j)
- p = os.path.join(path, str(i))
- # do not check memory during merging
- self.mergeCombiners(self.serializer.load_stream(open(p)),
- False)
-
- # limit the total partitions
- if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
- and j < self.spills - 1
- and get_used_memory() > hard_limit):
- self.data.clear() # will read from disk again
- gc.collect() # release the memory as much as possible
- for v in self._recursive_merged_items(i):
- yield v
- return
-
- for v in self.data.iteritems():
+ for v in self._merged_items(i):
yield v
self.data.clear()
@@ -400,53 +388,56 @@ class ExternalMerger(Merger):
for j in range(self.spills):
path = self._get_spill_dir(j)
os.remove(os.path.join(path, str(i)))
-
finally:
self._cleanup()
- def _cleanup(self):
- """ Clean up all the files in disks """
- for d in self.localdirs:
- shutil.rmtree(d, True)
+ def _merged_items(self, index):
+ self.data = {}
+ limit = self._next_limit()
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(index))
+ # do not check memory during merging
+ self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+
+ # limit the total partitions
+ if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
+ and j < self.spills - 1
+ and get_used_memory() > limit):
+ self.data.clear() # will read from disk again
+ gc.collect() # release the memory as much as possible
+ return self._recursive_merged_items(index)
- def _recursive_merged_items(self, start):
+ return self.data.iteritems()
+
+ def _recursive_merged_items(self, index):
"""
merge the partitioned items and return the as iterator
If one partition can not be fit in memory, then them will be
partitioned and merged recursively.
"""
- # make sure all the data are dumps into disks.
- assert not self.data
- if any(self.pdata):
- self._spill()
- assert self.spills > 0
-
- for i in range(start, self.partitions):
- subdirs = [os.path.join(d, "parts", str(i))
- for d in self.localdirs]
- m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
- subdirs, self.scale * self.partitions, self.partitions)
- m.pdata = [{} for _ in range(self.partitions)]
- limit = self._next_limit()
-
- for j in range(self.spills):
- path = self._get_spill_dir(j)
- p = os.path.join(path, str(i))
- m._partitioned_mergeCombiners(
- self.serializer.load_stream(open(p)))
-
- if get_used_memory() > limit:
- m._spill()
- limit = self._next_limit()
+ subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs]
+ m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs,
+ self.scale * self.partitions, self.partitions, self.batch)
+ m.pdata = [{} for _ in range(self.partitions)]
+ limit = self._next_limit()
+
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(index))
+ m.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+
+ if get_used_memory() > limit:
+ m._spill()
+ limit = self._next_limit()
- for v in m._external_items():
- yield v
+ return m._external_items()
- # remove the merged partition
- for j in range(self.spills):
- path = self._get_spill_dir(j)
- os.remove(os.path.join(path, str(i)))
+ def _cleanup(self):
+ """ Clean up all the files in disks """
+ for d in self.localdirs:
+ shutil.rmtree(d, True)
class ExternalSorter(object):
@@ -457,6 +448,7 @@ class ExternalSorter(object):
The spilling will only happen when the used memory goes above
the limit.
+
>>> sorter = ExternalSorter(1) # 1M
>>> import random
>>> l = range(1024)
@@ -469,7 +461,7 @@ class ExternalSorter(object):
def __init__(self, memory_limit, serializer=None):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
- self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
+ self.serializer = _compressed_serializer(serializer)
def _get_path(self, n):
""" Choose one directory for spill by number n """
@@ -515,6 +507,7 @@ class ExternalSorter(object):
limit = self._next_limit()
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
DiskBytesSpilled += os.path.getsize(path)
+ os.unlink(path) # data will be deleted after close
elif not chunks:
batch = min(batch * 2, 10000)
@@ -529,6 +522,310 @@ class ExternalSorter(object):
return heapq.merge(chunks, key=key, reverse=reverse)
+class ExternalList(object):
+ """
+ ExternalList can have many items which cannot be hold in memory in
+ the same time.
+
+ >>> l = ExternalList(range(100))
+ >>> len(l)
+ 100
+ >>> l.append(10)
+ >>> len(l)
+ 101
+ >>> for i in range(20240):
+ ... l.append(i)
+ >>> len(l)
+ 20341
+ >>> import pickle
+ >>> l2 = pickle.loads(pickle.dumps(l))
+ >>> len(l2)
+ 20341
+ >>> list(l2)[100]
+ 10
+ """
+ LIMIT = 10240
+
+ def __init__(self, values):
+ self.values = values
+ self.count = len(values)
+ self._file = None
+ self._ser = None
+
+ def __getstate__(self):
+ if self._file is not None:
+ self._file.flush()
+ f = os.fdopen(os.dup(self._file.fileno()))
+ f.seek(0)
+ serialized = f.read()
+ else:
+ serialized = ''
+ return self.values, self.count, serialized
+
+ def __setstate__(self, item):
+ self.values, self.count, serialized = item
+ if serialized:
+ self._open_file()
+ self._file.write(serialized)
+ else:
+ self._file = None
+ self._ser = None
+
+ def __iter__(self):
+ if self._file is not None:
+ self._file.flush()
+ # read all items from disks first
+ with os.fdopen(os.dup(self._file.fileno()), 'r') as f:
+ f.seek(0)
+ for v in self._ser.load_stream(f):
+ yield v
+
+ for v in self.values:
+ yield v
+
+ def __len__(self):
+ return self.count
+
+ def append(self, value):
+ self.values.append(value)
+ self.count += 1
+ # dump them into disk if the key is huge
+ if len(self.values) >= self.LIMIT:
+ self._spill()
+
+ def _open_file(self):
+ dirs = _get_local_dirs("objects")
+ d = dirs[id(self) % len(dirs)]
+ if not os.path.exists(d):
+ os.makedirs(d)
+ p = os.path.join(d, str(id))
+ self._file = open(p, "w+", 65536)
+ self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
+ os.unlink(p)
+
+ def _spill(self):
+ """ dump the values into disk """
+ global MemoryBytesSpilled, DiskBytesSpilled
+ if self._file is None:
+ self._open_file()
+
+ used_memory = get_used_memory()
+ pos = self._file.tell()
+ self._ser.dump_stream(self.values, self._file)
+ self.values = []
+ gc.collect()
+ DiskBytesSpilled += self._file.tell() - pos
+ MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+
+
+class ExternalListOfList(ExternalList):
+ """
+ An external list for list.
+
+ >>> l = ExternalListOfList([[i, i] for i in range(100)])
+ >>> len(l)
+ 200
+ >>> l.append(range(10))
+ >>> len(l)
+ 210
+ >>> len(list(l))
+ 210
+ """
+
+ def __init__(self, values):
+ ExternalList.__init__(self, values)
+ self.count = sum(len(i) for i in values)
+
+ def append(self, value):
+ ExternalList.append(self, value)
+ # already counted 1 in ExternalList.append
+ self.count += len(value) - 1
+
+ def __iter__(self):
+ for values in ExternalList.__iter__(self):
+ for v in values:
+ yield v
+
+
+class GroupByKey(object):
+ """
+ Group a sorted iterator as [(k1, it1), (k2, it2), ...]
+
+ >>> k = [i/3 for i in range(6)]
+ >>> v = [[i] for i in range(6)]
+ >>> g = GroupByKey(iter(zip(k, v)))
+ >>> [(k, list(it)) for k, it in g]
+ [(0, [0, 1, 2]), (1, [3, 4, 5])]
+ """
+
+ def __init__(self, iterator):
+ self.iterator = iter(iterator)
+ self.next_item = None
+
+ def __iter__(self):
+ return self
+
+ def next(self):
+ key, value = self.next_item if self.next_item else next(self.iterator)
+ values = ExternalListOfList([value])
+ try:
+ while True:
+ k, v = next(self.iterator)
+ if k != key:
+ self.next_item = (k, v)
+ break
+ values.append(v)
+ except StopIteration:
+ self.next_item = None
+ return key, values
+
+
+class ExternalGroupBy(ExternalMerger):
+
+ """
+ Group by the items by key. If any partition of them can not been
+ hold in memory, it will do sort based group by.
+
+ This class works as follows:
+
+ - It repeatedly group the items by key and save them in one dict in
+ memory.
+
+ - When the used memory goes above memory limit, it will split
+ the combined data into partitions by hash code, dump them
+ into disk, one file per partition. If the number of keys
+ in one partitions is smaller than 1000, it will sort them
+ by key before dumping into disk.
+
+ - Then it goes through the rest of the iterator, group items
+ by key into different dict by hash. Until the used memory goes over
+ memory limit, it dump all the dicts into disks, one file per
+ dict. Repeat this again until combine all the items. It
+ also will try to sort the items by key in each partition
+ before dumping into disks.
+
+ - It will yield the grouped items partitions by partitions.
+ If the data in one partitions can be hold in memory, then it
+ will load and combine them in memory and yield.
+
+ - If the dataset in one partition cannot be hold in memory,
+ it will sort them first. If all the files are already sorted,
+ it merge them by heap.merge(), so it will do external sort
+ for all the files.
+
+ - After sorting, `GroupByKey` class will put all the continuous
+ items with the same key as a group, yield the values as
+ an iterator.
+ """
+ SORT_KEY_LIMIT = 1000
+
+ def flattened_serializer(self):
+ assert isinstance(self.serializer, BatchedSerializer)
+ ser = self.serializer
+ return FlattenedValuesSerializer(ser, 20)
+
+ def _object_size(self, obj):
+ return len(obj)
+
+ def _spill(self):
+ """
+ dump already partitioned data into disks.
+ """
+ global MemoryBytesSpilled, DiskBytesSpilled
+ path = self._get_spill_dir(self.spills)
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ used_memory = get_used_memory()
+ if not self.pdata:
+ # The data has not been partitioned, it will iterator the
+ # data once, write them into different files, has no
+ # additional memory. It only called when the memory goes
+ # above limit at the first time.
+
+ # open all the files for writing
+ streams = [open(os.path.join(path, str(i)), 'w')
+ for i in range(self.partitions)]
+
+ # If the number of keys is small, then the overhead of sort is small
+ # sort them before dumping into disks
+ self._sorted = len(self.data) < self.SORT_KEY_LIMIT
+ if self._sorted:
+ self.serializer = self.flattened_serializer()
+ for k in sorted(self.data.keys()):
+ h = self._partition(k)
+ self.serializer.dump_stream([(k, self.data[k])], streams[h])
+ else:
+ for k, v in self.data.iteritems():
+ h = self._partition(k)
+ self.serializer.dump_stream([(k, v)], streams[h])
+
+ for s in streams:
+ DiskBytesSpilled += s.tell()
+ s.close()
+
+ self.data.clear()
+ # self.pdata is cached in `mergeValues` and `mergeCombiners`
+ self.pdata.extend([{} for i in range(self.partitions)])
+
+ else:
+ for i in range(self.partitions):
+ p = os.path.join(path, str(i))
+ with open(p, "w") as f:
+ # dump items in batch
+ if self._sorted:
+ # sort by key only (stable)
+ sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0))
+ self.serializer.dump_stream(sorted_items, f)
+ else:
+ self.serializer.dump_stream(self.pdata[i].iteritems(), f)
+ self.pdata[i].clear()
+ DiskBytesSpilled += os.path.getsize(p)
+
+ self.spills += 1
+ gc.collect() # release the memory as much as possible
+ MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+
+ def _merged_items(self, index):
+ size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
+ for j in range(self.spills))
+ # if the memory can not hold all the partition,
+ # then use sort based merge. Because of compression,
+ # the data on disks will be much smaller than needed memory
+ if (size >> 20) >= self.memory_limit / 10:
+ return self._merge_sorted_items(index)
+
+ self.data = {}
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(index))
+ # do not check memory during merging
+ self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+ return self.data.iteritems()
+
+ def _merge_sorted_items(self, index):
+ """ load a partition from disk, then sort and group by key """
+ def load_partition(j):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(index))
+ return self.serializer.load_stream(open(p, 'r', 65536))
+
+ disk_items = [load_partition(j) for j in range(self.spills)]
+
+ if self._sorted:
+ # all the partitions are already sorted
+ sorted_items = heapq.merge(disk_items, key=operator.itemgetter(0))
+
+ else:
+ # Flatten the combined values, so it will not consume huge
+ # memory during merging sort.
+ ser = self.flattened_serializer()
+ sorter = ExternalSorter(self.memory_limit, ser)
+ sorted_items = sorter.sorted(itertools.chain(*disk_items),
+ key=operator.itemgetter(0))
+ return ((k, vs) for k, vs in GroupByKey(sorted_items))
+
+
if __name__ == "__main__":
import doctest
doctest.testmod()
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index dd8d3b1c53..0bd5d20f78 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -31,6 +31,7 @@ import tempfile
import time
import zipfile
import random
+import itertools
import threading
import hashlib
@@ -76,7 +77,7 @@ SPARK_HOME = os.environ["SPARK_HOME"]
class MergerTests(unittest.TestCase):
def setUp(self):
- self.N = 1 << 14
+ self.N = 1 << 12
self.l = [i for i in xrange(self.N)]
self.data = zip(self.l, self.l)
self.agg = Aggregator(lambda x: [x],
@@ -108,7 +109,7 @@ class MergerTests(unittest.TestCase):
sum(xrange(self.N)))
def test_medium_dataset(self):
- m = ExternalMerger(self.agg, 10)
+ m = ExternalMerger(self.agg, 30)
m.mergeValues(self.data)
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
@@ -124,10 +125,36 @@ class MergerTests(unittest.TestCase):
m = ExternalMerger(self.agg, 10, partitions=3)
m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
self.assertTrue(m.spills >= 1)
- self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
+ self.assertEqual(sum(len(v) for k, v in m.iteritems()),
self.N * 10)
m._cleanup()
+ def test_group_by_key(self):
+
+ def gen_data(N, step):
+ for i in range(1, N + 1, step):
+ for j in range(i):
+ yield (i, [j])
+
+ def gen_gs(N, step=1):
+ return shuffle.GroupByKey(gen_data(N, step))
+
+ self.assertEqual(1, len(list(gen_gs(1))))
+ self.assertEqual(2, len(list(gen_gs(2))))
+ self.assertEqual(100, len(list(gen_gs(100))))
+ self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)])
+ self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100)))
+
+ for k, vs in gen_gs(50002, 10000):
+ self.assertEqual(k, len(vs))
+ self.assertEqual(range(k), list(vs))
+
+ ser = PickleSerializer()
+ l = ser.loads(ser.dumps(list(gen_gs(50002, 30000))))
+ for k, vs in l:
+ self.assertEqual(k, len(vs))
+ self.assertEqual(range(k), list(vs))
+
class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
@@ -702,6 +729,21 @@ class RDDTests(ReusedPySparkTestCase):
self.assertEquals(result.getNumPartitions(), 5)
self.assertEquals(result.count(), 3)
+ def test_external_group_by_key(self):
+ self.sc._conf.set("spark.python.worker.memory", "5m")
+ N = 200001
+ kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x))
+ gkv = kv.groupByKey().cache()
+ self.assertEqual(3, gkv.count())
+ filtered = gkv.filter(lambda (k, vs): k == 1)
+ self.assertEqual(1, filtered.count())
+ self.assertEqual([(1, N/3)], filtered.mapValues(len).collect())
+ self.assertEqual([(N/3, N/3)],
+ filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
+ result = filtered.collect()[0][1]
+ self.assertEqual(N/3, len(result))
+ self.assertTrue(isinstance(result.data, shuffle.ExternalList))
+
def test_sort_on_empty_rdd(self):
self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
@@ -752,9 +794,9 @@ class RDDTests(ReusedPySparkTestCase):
self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions())
self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions())
- self.sc.setJobGroup("test1", "test", True)
tracker = self.sc.statusTracker()
+ self.sc.setJobGroup("test1", "test", True)
d = sorted(parted.join(parted).collect())
self.assertEqual(10, len(d))
self.assertEqual((0, (0, 0)), d[0])