Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py  92
1 file changed, 71 insertions(+), 21 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index a38dd0b923..7ad6108261 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -42,6 +42,8 @@ from pyspark.statcounter import StatCounter
from pyspark.rddsampler import RDDSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
+from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
+ get_used_memory
from py4j.java_collections import ListConverter, MapConverter
@@ -197,6 +199,22 @@ class MaxHeapQ(object):
self._sink(1)
+def _parse_memory(s):
+ """
+ Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
+ return the value in MB
+
+ >>> _parse_memory("256m")
+ 256
+ >>> _parse_memory("2g")
+ 2048
+ """
+ units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024}
+ if s[-1].lower() not in units:
+ raise ValueError("invalid format: " + s)
+ return int(float(s[:-1]) * units[s[-1].lower()])
+
+
class RDD(object):
"""
@@ -1207,20 +1225,49 @@ class RDD(object):
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
- # Transferring O(n) objects to Java is too expensive. Instead, we'll
- # form the hash buckets in Python, transferring O(numPartitions) objects
- # to Java. Each object is a (splitNumber, [objects]) pair.
+ # Transferring O(n) objects to Java is too expensive.
+ # Instead, we'll form the hash buckets in Python,
+ # transferring O(numPartitions) objects to Java.
+ # Each object is a (splitNumber, [objects]) pair.
+ # To avoid building objects that are too large, the objects
+ # are grouped into chunks.
outputSerializer = self.ctx._unbatched_serializer
+ limit = (_parse_memory(self.ctx._conf.get(
+ "spark.python.worker.memory", "512m")) / 2)
+
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
+ c, batch = 0, min(10 * numPartitions, 1000)
for (k, v) in iterator:
buckets[partitionFunc(k) % numPartitions].append((k, v))
+ c += 1
+
+ # check used memory and avg size of chunk of objects
+ if (c % 1000 == 0 and get_used_memory() > limit
+ or c > batch):
+ n, size = len(buckets), 0
+ for split in buckets.keys():
+ yield pack_long(split)
+ d = outputSerializer.dumps(buckets[split])
+ del buckets[split]
+ yield d
+ size += len(d)
+
+ avg = (size / n) >> 20
+ # let 1M < avg < 10M
+ if avg < 1:
+ batch *= 1.5
+ elif avg > 10:
+ batch = max(batch / 1.5, 1)
+ c = 0
+
for (split, items) in buckets.iteritems():
yield pack_long(split)
yield outputSerializer.dumps(items)
+
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
with _JavaStackTrace(self.context) as st:
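(Not part of the patch: the flush-and-adapt logic in add_shuffle_key is easier to follow outside the diff. The sketch below is a simplified restatement, using pickle as a stand-in serializer and a caller-supplied used_memory_mb() probe in place of get_used_memory(); the name chunked_buckets is hypothetical.)

    import pickle
    from collections import defaultdict

    def chunked_buckets(iterator, num_partitions, partition_func,
                        limit_mb, used_memory_mb, dumps=pickle.dumps):
        # Same idea as add_shuffle_key above: buffer (key, value) pairs into
        # per-partition buckets, flush early when memory use passes limit_mb
        # or when `batch` records are buffered, and tune `batch` so the
        # average serialized chunk stays roughly between 1 MB and 10 MB.
        buckets = defaultdict(list)
        c, batch = 0, min(10 * num_partitions, 1000)
        for k, v in iterator:
            buckets[partition_func(k) % num_partitions].append((k, v))
            c += 1
            if (c % 1000 == 0 and used_memory_mb() > limit_mb) or c > batch:
                n, size = len(buckets), 0
                for split in list(buckets):
                    data = dumps(buckets.pop(split))
                    size += len(data)
                    yield split, data
                avg_mb = (size // n) >> 20      # average chunk size in MB
                if avg_mb < 1:                  # chunks too small: buffer more
                    batch *= 1.5
                elif avg_mb > 10:               # chunks too big: flush sooner
                    batch = max(batch / 1.5, 1)
                c = 0
        for split, items in buckets.items():    # flush whatever is left
            yield split, dumps(items)

Each yielded pair is one (partition id, serialized chunk), which mirrors the alternating pack_long(split) / serialized-bucket stream produced above.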
@@ -1230,8 +1277,8 @@ class RDD(object):
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
- # This is required so that id(partitionFunc) remains unique, even if
- # partitionFunc is a lambda:
+ # This is required so that id(partitionFunc) remains unique,
+ # even if partitionFunc is a lambda:
rdd._partitionFunc = partitionFunc
return rdd
@@ -1265,26 +1312,28 @@ class RDD(object):
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
+ serializer = self.ctx.serializer
+ spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
+ == 'true')
+ memory = _parse_memory(self.ctx._conf.get(
+ "spark.python.worker.memory", "512m"))
+ agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
+
def combineLocally(iterator):
- combiners = {}
- for x in iterator:
- (k, v) = x
- if k not in combiners:
- combiners[k] = createCombiner(v)
- else:
- combiners[k] = mergeValue(combiners[k], v)
- return combiners.iteritems()
+ merger = ExternalMerger(agg, memory * 0.9, serializer) \
+ if spill else InMemoryMerger(agg)
+ merger.mergeValues(iterator)
+ return merger.iteritems()
+
locally_combined = self.mapPartitions(combineLocally)
shuffled = locally_combined.partitionBy(numPartitions)
def _mergeCombiners(iterator):
- combiners = {}
- for (k, v) in iterator:
- if k not in combiners:
- combiners[k] = v
- else:
- combiners[k] = mergeCombiners(combiners[k], v)
- return combiners.iteritems()
+ merger = ExternalMerger(agg, memory, serializer) \
+ if spill else InMemoryMerger(agg)
+ merger.mergeCombiners(iterator)
+ return merger.iteritems()
+
return shuffled.mapPartitions(_mergeCombiners)
def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
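(Not part of the patch: a minimal end-to-end sketch of the configuration that drives the spill path above. The two config keys are the ones read in this hunk; the values, the local master, and the job itself are illustrative only.)

    from operator import add
    from pyspark import SparkConf, SparkContext

    # Cap each Python worker at roughly 512 MB and let aggregation state
    # spill to disk through ExternalMerger when that budget is exceeded.
    conf = (SparkConf()
            .set("spark.python.worker.memory", "512m")
            .set("spark.shuffle.spill", "true"))
    sc = SparkContext("local[2]", "spill-demo", conf=conf)

    pairs = sc.parallelize(range(100000)).map(lambda x: (x % 10, 1))
    counts = pairs.reduceByKey(add)      # reduceByKey funnels into combineByKey above
    print(counts.collectAsMap())         # {0: 10000, 1: 10000, ..., 9: 10000}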
@@ -1343,7 +1392,8 @@ class RDD(object):
return xs
def mergeCombiners(a, b):
- return a + b
+ a.extend(b)
+ return a
return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
numPartitions).mapValues(lambda x: ResultIterable(x))
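(Not part of the patch: the mergeCombiners tweak above only changes how combiner lists are merged, extending in place instead of concatenating, so the user-visible behaviour of groupByKey is unchanged: each key still maps to a ResultIterable of its values. Reusing the sc from the previous sketch:)

    grouped = sc.parallelize([("a", 1), ("b", 1), ("a", 2)]).groupByKey()
    # Each value is a ResultIterable; materialize it to inspect the contents.
    print(sorted(grouped.mapValues(sorted).collect()))
    # [('a', [1, 2]), ('b', [1])]
    sc.stop()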