-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala        5
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala    2
-rw-r--r--  docs/configuration.md                                                   9
-rw-r--r--  python/epydoc.conf                                                      2
-rw-r--r--  python/pyspark/rdd.py                                                  92
-rw-r--r--  python/pyspark/serializers.py                                          29
-rw-r--r--  python/pyspark/shuffle.py                                             439
-rw-r--r--  python/pyspark/tests.py                                                57
-rwxr-xr-x  python/run-tests                                                        1
9 files changed, 611 insertions, 25 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 462e09466b..d6b0988641 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -57,7 +57,10 @@ private[spark] class PythonRDD[T: ClassTag](
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
- val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
+ val localdir = env.blockManager.diskBlockManager.localDirs.map(
+ f => f.getPath()).mkString(",")
+ val worker: Socket = env.createPythonWorker(pythonExec,
+ envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir))
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 673fc19c06..2e7ed7538e 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -43,7 +43,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
/* Create one local directory for each path mentioned in spark.local.dir; then, inside this
* directory, create multiple subdirectories that we will hash files into, in order to avoid
* having really large inodes at the top level. */
- private val localDirs: Array[File] = createLocalDirs()
+ val localDirs: Array[File] = createLocalDirs()
if (localDirs.isEmpty) {
logError("Failed to create any local dir.")
System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
diff --git a/docs/configuration.md b/docs/configuration.md
index cb0c65e2d2..dac8bb1d52 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -197,6 +197,15 @@ Apart from these, the following properties are also available, and may be useful
Spark's dependencies and user dependencies. It is currently an experimental feature.
</td>
</tr>
+<tr>
+ <td><code>spark.python.worker.memory</code></td>
+ <td>512m</td>
+ <td>
+ Amount of memory to use per Python worker process during aggregation, in the same
+ format as JVM memory strings (e.g. <code>512m</code>, <code>2g</code>). If the memory
+ used during aggregation goes above this amount, the data will be spilled to disk.
+ </td>
+</tr>
</table>
#### Shuffle Behavior
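Not part of the patch: a minimal sketch of setting this option from PySpark, assuming the usual SparkConf workflow.

    from pyspark import SparkConf, SparkContext

    # Cap each Python worker at 256 MB during aggregation; the value uses
    # JVM memory-string syntax ("512m", "2g", ...), as documented above.
    conf = SparkConf().set("spark.python.worker.memory", "256m")
    sc = SparkContext(conf=conf)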
diff --git a/python/epydoc.conf b/python/epydoc.conf
index b73860bad8..51c0faf359 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -35,4 +35,4 @@ private: no
exclude: pyspark.cloudpickle pyspark.worker pyspark.join
pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests
pyspark.rddsampler pyspark.daemon pyspark.mllib._common
- pyspark.mllib.tests
+ pyspark.mllib.tests pyspark.shuffle
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index a38dd0b923..7ad6108261 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -42,6 +42,8 @@ from pyspark.statcounter import StatCounter
from pyspark.rddsampler import RDDSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
+from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
+ get_used_memory
from py4j.java_collections import ListConverter, MapConverter
@@ -197,6 +199,22 @@ class MaxHeapQ(object):
self._sink(1)
+def _parse_memory(s):
+ """
+ Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
+ return the value in MB
+
+ >>> _parse_memory("256m")
+ 256
+ >>> _parse_memory("2g")
+ 2048
+ """
+ units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024}
+ if s[-1].lower() not in units:
+ raise ValueError("invalid format: " + s)
+ return int(float(s[:-1]) * units[s[-1].lower()])
+
+
class RDD(object):
"""
@@ -1207,20 +1225,49 @@ class RDD(object):
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
- # Transferring O(n) objects to Java is too expensive. Instead, we'll
- # form the hash buckets in Python, transferring O(numPartitions) objects
- # to Java. Each object is a (splitNumber, [objects]) pair.
+ # Transferring O(n) objects to Java is too expensive.
+ # Instead, we'll form the hash buckets in Python,
+ # transferring O(numPartitions) objects to Java.
+ # Each object is a (splitNumber, [objects]) pair.
+ # To avoid building overly large objects, the items are
+ # grouped into chunks.
outputSerializer = self.ctx._unbatched_serializer
+ limit = (_parse_memory(self.ctx._conf.get(
+ "spark.python.worker.memory", "512m")) / 2)
+
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
+ c, batch = 0, min(10 * numPartitions, 1000)
for (k, v) in iterator:
buckets[partitionFunc(k) % numPartitions].append((k, v))
+ c += 1
+
+ # check used memory and avg size of chunk of objects
+ if (c % 1000 == 0 and get_used_memory() > limit
+ or c > batch):
+ n, size = len(buckets), 0
+ for split in buckets.keys():
+ yield pack_long(split)
+ d = outputSerializer.dumps(buckets[split])
+ del buckets[split]
+ yield d
+ size += len(d)
+
+ avg = (size / n) >> 20
+ # let 1M < avg < 10M
+ if avg < 1:
+ batch *= 1.5
+ elif avg > 10:
+ batch = max(batch / 1.5, 1)
+ c = 0
+
for (split, items) in buckets.iteritems():
yield pack_long(split)
yield outputSerializer.dumps(items)
+
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
with _JavaStackTrace(self.context) as st:
@@ -1230,8 +1277,8 @@ class RDD(object):
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
- # This is required so that id(partitionFunc) remains unique, even if
- # partitionFunc is a lambda:
+ # This is required so that id(partitionFunc) remains unique,
+ # even if partitionFunc is a lambda:
rdd._partitionFunc = partitionFunc
return rdd
@@ -1265,26 +1312,28 @@ class RDD(object):
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
+ serializer = self.ctx.serializer
+ spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
+ == 'true')
+ memory = _parse_memory(self.ctx._conf.get(
+ "spark.python.worker.memory", "512m"))
+ agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
+
def combineLocally(iterator):
- combiners = {}
- for x in iterator:
- (k, v) = x
- if k not in combiners:
- combiners[k] = createCombiner(v)
- else:
- combiners[k] = mergeValue(combiners[k], v)
- return combiners.iteritems()
+ merger = ExternalMerger(agg, memory * 0.9, serializer) \
+ if spill else InMemoryMerger(agg)
+ merger.mergeValues(iterator)
+ return merger.iteritems()
+
locally_combined = self.mapPartitions(combineLocally)
shuffled = locally_combined.partitionBy(numPartitions)
def _mergeCombiners(iterator):
- combiners = {}
- for (k, v) in iterator:
- if k not in combiners:
- combiners[k] = v
- else:
- combiners[k] = mergeCombiners(combiners[k], v)
- return combiners.iteritems()
+ merger = ExternalMerger(agg, memory, serializer) \
+ if spill else InMemoryMerger(agg)
+ merger.mergeCombiners(iterator)
+ return merger.iteritems()
+
return shuffled.mapPartitions(_mergeCombiners)
def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
@@ -1343,7 +1392,8 @@ class RDD(object):
return xs
def mergeCombiners(a, b):
- return a + b
+ a.extend(b)
+ return a
return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
numPartitions).mapValues(lambda x: ResultIterable(x))
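The combineByKey hunk above now routes both local combining and post-shuffle merging through InMemoryMerger or ExternalMerger. A minimal sketch of a call that goes down this path (not part of the patch; sc and the sample data are illustrative only); with spark.shuffle.spill enabled, each worker spills to disk once it exceeds spark.python.worker.memory.

    pairs = sc.parallelize([("a", 1), ("b", 1), ("a", 2)], 4)
    sums = pairs.combineByKey(
        lambda v: v,               # createCombiner
        lambda c, v: c + v,        # mergeValue
        lambda c1, c2: c1 + c2)    # mergeCombiners
    print(sums.collectAsMap())     # {'a': 3, 'b': 1}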
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 9be78b39fb..03b31ae962 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -193,7 +193,7 @@ class BatchedSerializer(Serializer):
return chain.from_iterable(self._load_stream_without_unbatching(stream))
def _load_stream_without_unbatching(self, stream):
- return self.serializer.load_stream(stream)
+ return self.serializer.load_stream(stream)
def __eq__(self, other):
return (isinstance(other, BatchedSerializer) and
@@ -302,6 +302,33 @@ class MarshalSerializer(FramedSerializer):
loads = marshal.loads
+class AutoSerializer(FramedSerializer):
+ """
+ Choose marshal or cPickle as serialization protocol automatically
+ """
+ def __init__(self):
+ FramedSerializer.__init__(self)
+ self._type = None
+
+ def dumps(self, obj):
+ if self._type is not None:
+ return 'P' + cPickle.dumps(obj, -1)
+ try:
+ return 'M' + marshal.dumps(obj)
+ except Exception:
+ self._type = 'P'
+ return 'P' + cPickle.dumps(obj, -1)
+
+ def loads(self, obj):
+ _type = obj[0]
+ if _type == 'M':
+ return marshal.loads(obj[1:])
+ elif _type == 'P':
+ return cPickle.loads(obj[1:])
+ else:
+ raise ValueError("invalid sevialization type: %s" % _type)
+
+
class UTF8Deserializer(Serializer):
"""
Deserializes streams written by String.getBytes.
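Not part of the patch: a short sketch of AutoSerializer's fallback behaviour. Simple values round-trip through marshal; the first object marshal cannot handle (here a datetime, chosen only for illustration) switches the serializer to cPickle for all later calls.

    from datetime import datetime
    from pyspark.serializers import AutoSerializer

    ser = AutoSerializer()
    blob = ser.dumps((1, "a"))
    assert blob[0] == 'M' and ser.loads(blob) == (1, "a")   # marshal path
    blob = ser.dumps(datetime.now())   # marshal cannot serialize this
    assert blob[0] == 'P'              # falls back to cPickle from now on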
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
new file mode 100644
index 0000000000..e3923d1c36
--- /dev/null
+++ b/python/pyspark/shuffle.py
@@ -0,0 +1,439 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import platform
+import shutil
+import warnings
+import gc
+
+from pyspark.serializers import BatchedSerializer, PickleSerializer
+
+try:
+ import psutil
+
+ def get_used_memory():
+ """ Return the used memory in MB """
+ process = psutil.Process(os.getpid())
+ if hasattr(process, "memory_info"):
+ info = process.memory_info()
+ else:
+ info = process.get_memory_info()
+ return info.rss >> 20
+except ImportError:
+
+ def get_used_memory():
+ """ Return the used memory in MB """
+ if platform.system() == 'Linux':
+ for line in open('/proc/self/status'):
+ if line.startswith('VmRSS:'):
+ return int(line.split()[1]) >> 10
+ else:
+ warnings.warn("Please install psutil to have better "
+ "support with spilling")
+ if platform.system() == "Darwin":
+ import resource
+ rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+ return rss >> 20
+ # TODO: support windows
+ return 0
+
+
+class Aggregator(object):
+
+ """
+ Aggregator has three functions to merge values into a combiner:
+
+ createCombiner: (value) -> combiner
+ mergeValue: (combiner, value) -> combiner
+ mergeCombiners: (combiner, combiner) -> combiner
+ """
+
+ def __init__(self, createCombiner, mergeValue, mergeCombiners):
+ self.createCombiner = createCombiner
+ self.mergeValue = mergeValue
+ self.mergeCombiners = mergeCombiners
+
+
+class SimpleAggregator(Aggregator):
+
+ """
+ SimpleAggregator is useful for cases in which the combiner has
+ the same type as the values
+ """
+
+ def __init__(self, combiner):
+ Aggregator.__init__(self, lambda x: x, combiner, combiner)
+
+
+class Merger(object):
+
+ """
+ Merge shuffled data together using an aggregator
+ """
+
+ def __init__(self, aggregator):
+ self.agg = aggregator
+
+ def mergeValues(self, iterator):
+ """ Combine the items by creator and combiner """
+ raise NotImplementedError
+
+ def mergeCombiners(self, iterator):
+ """ Merge the combined items by mergeCombiner """
+ raise NotImplementedError
+
+ def iteritems(self):
+ """ Return the merged items ad iterator """
+ raise NotImplementedError
+
+
+class InMemoryMerger(Merger):
+
+ """
+ In-memory merger based on a single dict.
+ """
+
+ def __init__(self, aggregator):
+ Merger.__init__(self, aggregator)
+ self.data = {}
+
+ def mergeValues(self, iterator):
+ """ Combine the items by creator and combiner """
+ # speed up attributes lookup
+ d, creator = self.data, self.agg.createCombiner
+ comb = self.agg.mergeValue
+ for k, v in iterator:
+ d[k] = comb(d[k], v) if k in d else creator(v)
+
+ def mergeCombiners(self, iterator):
+ """ Merge the combined items by mergeCombiner """
+ # speed up attributes lookup
+ d, comb = self.data, self.agg.mergeCombiners
+ for k, v in iterator:
+ d[k] = comb(d[k], v) if k in d else v
+
+ def iteritems(self):
+ """ Return the merged items ad iterator """
+ return self.data.iteritems()
+
+
+class ExternalMerger(Merger):
+
+ """
+ ExternalMerger dumps the aggregated data to disk when memory
+ usage goes above the limit, then merges it back together.
+
+ This class works as follows:
+
+ - It repeatedly combines the items and saves them in one dict in
+ memory.
+
+ - When the used memory goes above the memory limit, it splits the
+ combined data into partitions by hash code and dumps them to
+ disk, one file per partition.
+
+ - Then it goes through the rest of the iterator, combining items
+ into different dicts by hash. Whenever the used memory goes over
+ the memory limit again, it dumps all the dicts to disk, one file
+ per dict, and repeats until all the items have been combined.
+
+ - Before returning any items, it loads each partition and combines
+ it separately, yielding its items before loading the next
+ partition.
+
+ - While loading a partition, if the memory goes over the limit,
+ the loaded data is re-partitioned, dumped to disk, and loaded
+ again partition by partition.
+
+ `data` and `pdata` hold the merged items in memory. At first, all
+ items are merged into `data`. Once the used memory goes over the
+ limit, the items in `data` are dumped to disk and `data` is
+ cleared; all remaining items are then merged into `pdata` and
+ dumped to disk as well. Before returning, any items still in
+ `pdata` are dumped to disk.
+
+ Finally, if any items were spilled to disk, each partition is
+ merged into `data`, yielded, and then cleared.
+
+ >>> agg = SimpleAggregator(lambda x, y: x + y)
+ >>> merger = ExternalMerger(agg, 10)
+ >>> N = 10000
+ >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10)
+ >>> assert merger.spills > 0
+ >>> sum(v for k,v in merger.iteritems())
+ 499950000
+
+ >>> merger = ExternalMerger(agg, 10)
+ >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10)
+ >>> assert merger.spills > 0
+ >>> sum(v for k,v in merger.iteritems())
+ 499950000
+ """
+
+ # the max total partitions created recursively
+ MAX_TOTAL_PARTITIONS = 4096
+
+ def __init__(self, aggregator, memory_limit=512, serializer=None,
+ localdirs=None, scale=1, partitions=59, batch=1000):
+ Merger.__init__(self, aggregator)
+ self.memory_limit = memory_limit
+ # default serializer is only used for tests
+ self.serializer = serializer or \
+ BatchedSerializer(PickleSerializer(), 1024)
+ self.localdirs = localdirs or self._get_dirs()
+ # number of partitions used when spilling data to disk
+ self.partitions = partitions
+ # check memory usage after this many items have been merged
+ self.batch = batch
+ # scale is used to scale down the hash of key for recursive hash map
+ self.scale = scale
+ # unpartitioned merged data
+ self.data = {}
+ # partitioned merged data, list of dicts
+ self.pdata = []
+ # number of chunks dumped into disks
+ self.spills = 0
+ # randomize the hash of key, id(o) is the address of o (aligned by 8)
+ self._seed = id(self) + 7
+
+ def _get_dirs(self):
+ """ Get all the directories """
+ path = os.environ.get("SPARK_LOCAL_DIR", "/tmp")
+ dirs = path.split(",")
+ return [os.path.join(d, "python", str(os.getpid()), str(id(self)))
+ for d in dirs]
+
+ def _get_spill_dir(self, n):
+ """ Choose one directory for spill by number n """
+ return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))
+
+ def _next_limit(self):
+ """
+ Return the next memory limit. If memory is not released after
+ spilling, the data will only be dumped again once the used memory
+ starts to increase.
+ """
+ return max(self.memory_limit, get_used_memory() * 1.05)
+
+ def mergeValues(self, iterator):
+ """ Combine the items by creator and combiner """
+ iterator = iter(iterator)
+ # speedup attribute lookup
+ creator, comb = self.agg.createCombiner, self.agg.mergeValue
+ d, c, batch = self.data, 0, self.batch
+
+ for k, v in iterator:
+ d[k] = comb(d[k], v) if k in d else creator(v)
+
+ c += 1
+ if c % batch == 0 and get_used_memory() > self.memory_limit:
+ self._spill()
+ self._partitioned_mergeValues(iterator, self._next_limit())
+ break
+
+ def _partition(self, key):
+ """ Return the partition for key """
+ return hash((key, self._seed)) % self.partitions
+
+ def _partitioned_mergeValues(self, iterator, limit=0):
+ """ Partition the items by key, then combine them """
+ # speedup attribute lookup
+ creator, comb = self.agg.createCombiner, self.agg.mergeValue
+ c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch
+
+ for k, v in iterator:
+ d = pdata[hfun(k)]
+ d[k] = comb(d[k], v) if k in d else creator(v)
+ if not limit:
+ continue
+
+ c += 1
+ if c % batch == 0 and get_used_memory() > limit:
+ self._spill()
+ limit = self._next_limit()
+
+ def mergeCombiners(self, iterator, check=True):
+ """ Merge (K,V) pair by mergeCombiner """
+ iterator = iter(iterator)
+ # speedup attribute lookup
+ d, comb, batch = self.data, self.agg.mergeCombiners, self.batch
+ c = 0
+ for k, v in iterator:
+ d[k] = comb(d[k], v) if k in d else v
+ if not check:
+ continue
+
+ c += 1
+ if c % batch == 0 and get_used_memory() > self.memory_limit:
+ self._spill()
+ self._partitioned_mergeCombiners(iterator, self._next_limit())
+ break
+
+ def _partitioned_mergeCombiners(self, iterator, limit=0):
+ """ Partition the items by key, then merge them """
+ comb, pdata = self.agg.mergeCombiners, self.pdata
+ c, hfun = 0, self._partition
+ for k, v in iterator:
+ d = pdata[hfun(k)]
+ d[k] = comb(d[k], v) if k in d else v
+ if not limit:
+ continue
+
+ c += 1
+ if c % self.batch == 0 and get_used_memory() > limit:
+ self._spill()
+ limit = self._next_limit()
+
+ def _spill(self):
+ """
+ Dump already-partitioned data to disk.
+
+ The data is written in batches for better performance.
+ """
+ path = self._get_spill_dir(self.spills)
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ if not self.pdata:
+ # The data has not been partitioned yet, so iterate over the
+ # dataset once and write the items into separate files, using
+ # no additional memory. This branch is only taken the first
+ # time the memory goes above the limit.
+
+ # open all the files for writing
+ streams = [open(os.path.join(path, str(i)), 'w')
+ for i in range(self.partitions)]
+
+ for k, v in self.data.iteritems():
+ h = self._partition(k)
+ # write one item per batch to stay compatible with load_stream;
+ # dumping in larger batches here would increase memory usage
+ self.serializer.dump_stream([(k, v)], streams[h])
+
+ for s in streams:
+ s.close()
+
+ self.data.clear()
+ self.pdata = [{} for i in range(self.partitions)]
+
+ else:
+ for i in range(self.partitions):
+ p = os.path.join(path, str(i))
+ with open(p, "w") as f:
+ # dump items in batch
+ self.serializer.dump_stream(self.pdata[i].iteritems(), f)
+ self.pdata[i].clear()
+
+ self.spills += 1
+ gc.collect() # release the memory as much as possible
+
+ def iteritems(self):
+ """ Return all merged items as iterator """
+ if not self.pdata and not self.spills:
+ return self.data.iteritems()
+ return self._external_items()
+
+ def _external_items(self):
+ """ Return all partitioned items as iterator """
+ assert not self.data
+ if any(self.pdata):
+ self._spill()
+ hard_limit = self._next_limit()
+
+ try:
+ for i in range(self.partitions):
+ self.data = {}
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(i))
+ # do not check memory during merging
+ self.mergeCombiners(self.serializer.load_stream(open(p)),
+ False)
+
+ # limit the total partitions
+ if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
+ and j < self.spills - 1
+ and get_used_memory() > hard_limit):
+ self.data.clear() # will read from disk again
+ gc.collect() # release the memory as much as possible
+ for v in self._recursive_merged_items(i):
+ yield v
+ return
+
+ for v in self.data.iteritems():
+ yield v
+ self.data.clear()
+ gc.collect()
+
+ # remove the merged partition
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ os.remove(os.path.join(path, str(i)))
+
+ finally:
+ self._cleanup()
+
+ def _cleanup(self):
+ """ Clean up all the files in disks """
+ for d in self.localdirs:
+ shutil.rmtree(d, True)
+
+ def _recursive_merged_items(self, start):
+ """
+ Merge the partitioned items and return them as an iterator.
+
+ If one partition cannot fit in memory, it will be partitioned
+ and merged recursively.
+ """
+ # make sure all the data has been dumped to disk.
+ assert not self.data
+ if any(self.pdata):
+ self._spill()
+ assert self.spills > 0
+
+ for i in range(start, self.partitions):
+ subdirs = [os.path.join(d, "parts", str(i))
+ for d in self.localdirs]
+ m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
+ subdirs, self.scale * self.partitions)
+ m.pdata = [{} for _ in range(self.partitions)]
+ limit = self._next_limit()
+
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ p = os.path.join(path, str(i))
+ m._partitioned_mergeCombiners(
+ self.serializer.load_stream(open(p)))
+
+ if get_used_memory() > limit:
+ m._spill()
+ limit = self._next_limit()
+
+ for v in m._external_items():
+ yield v
+
+ # remove the merged partition
+ for j in range(self.spills):
+ path = self._get_spill_dir(j)
+ os.remove(os.path.join(path, str(i)))
+
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
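Not part of the patch: a standalone sketch that forces ExternalMerger to spill by giving it a tiny memory limit, mirroring the doctest above. It assumes get_used_memory() can report real usage on this platform (psutil or /proc available).

    from pyspark.shuffle import SimpleAggregator, ExternalMerger

    agg = SimpleAggregator(lambda x, y: x + y)           # sum values per key
    merger = ExternalMerger(agg, memory_limit=1, partitions=4)
    merger.mergeValues((i % 100, 1) for i in xrange(100000))
    assert merger.spills > 0                             # data hit the disk
    assert sum(v for _, v in merger.iteritems()) == 100000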
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 9c5ecd0bb0..a92abbf371 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -34,6 +34,7 @@ import zipfile
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.serializers import read_int
+from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
_have_scipy = False
try:
@@ -47,6 +48,62 @@ except:
SPARK_HOME = os.environ["SPARK_HOME"]
+class TestMerger(unittest.TestCase):
+
+ def setUp(self):
+ self.N = 1 << 16
+ self.l = [i for i in xrange(self.N)]
+ self.data = zip(self.l, self.l)
+ self.agg = Aggregator(lambda x: [x],
+ lambda x, y: x.append(y) or x,
+ lambda x, y: x.extend(y) or x)
+
+ def test_in_memory(self):
+ m = InMemoryMerger(self.agg)
+ m.mergeValues(self.data)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ m = InMemoryMerger(self.agg)
+ m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data))
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ def test_small_dataset(self):
+ m = ExternalMerger(self.agg, 1000)
+ m.mergeValues(self.data)
+ self.assertEqual(m.spills, 0)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ m = ExternalMerger(self.agg, 1000)
+ m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data))
+ self.assertEqual(m.spills, 0)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ def test_medium_dataset(self):
+ m = ExternalMerger(self.agg, 10)
+ m.mergeValues(self.data)
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)))
+
+ m = ExternalMerger(self.agg, 10)
+ m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3))
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
+ sum(xrange(self.N)) * 3)
+
+ def test_huge_dataset(self):
+ m = ExternalMerger(self.agg, 10)
+ m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
+ self.N * 10)
+ m._cleanup()
+
+
class PySparkTestCase(unittest.TestCase):
def setUp(self):
diff --git a/python/run-tests b/python/run-tests
index 9282aa47e8..29f755fc0d 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -61,6 +61,7 @@ run_test "pyspark/broadcast.py"
run_test "pyspark/accumulators.py"
run_test "pyspark/serializers.py"
unset PYSPARK_DOC_TEST
+run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
run_test "pyspark/mllib/_common.py"
run_test "pyspark/mllib/classification.py"