Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py | 48
1 file changed, 36 insertions(+), 12 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 2d05611321..1b18789040 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -41,7 +41,7 @@ from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
- get_used_memory, ExternalSorter
+ get_used_memory, ExternalSorter, ExternalGroupBy
from pyspark.traceback_utils import SCCallSiteSync
from py4j.java_collections import ListConverter, MapConverter
@@ -573,8 +573,8 @@ class RDD(object):
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
- spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
- memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+ spill = self._can_spill()
+ memory = self._memory_limit()
serializer = self._jrdd_deserializer
def sortPartition(iterator):
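For context, the two new helpers (added later in this diff) only centralize the conf lookups; the memory string itself goes through the module's existing _parse_memory helper, which returns megabytes. A minimal sketch of that parsing, modeled on pyspark.rdd._parse_memory and assuming the same unit suffixes PySpark accepts:

    def _parse_memory(s):
        # Parse a JVM-style memory string ("512m", "2g") into megabytes.
        units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024}
        if s[-1].lower() not in units:
            raise ValueError("invalid format: " + s)
        return int(float(s[:-1]) * units[s[-1].lower()])

    assert _parse_memory("512m") == 512
    assert _parse_memory("2g") == 2048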
@@ -1699,10 +1699,8 @@ class RDD(object):
numPartitions = self._defaultReducePartitions()
serializer = self.ctx.serializer
- spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
- == 'true')
- memory = _parse_memory(self.ctx._conf.get(
- "spark.python.worker.memory", "512m"))
+ spill = self._can_spill()
+ memory = self._memory_limit()
agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
def combineLocally(iterator):
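The Aggregator here just bundles the three merge callbacks (createCombiner, mergeValue, mergeCombiners). As an illustrative sketch (not the pyspark.shuffle implementation), the in-memory mergeValues pass over one partition amounts to:

    def merge_values(agg, iterator):
        # Fold each (k, v) pair into a dict of per-key combiners; this is the
        # no-spill path, with ExternalMerger taking over when memory is capped.
        combiners = {}
        for k, v in iterator:
            if k in combiners:
                combiners[k] = agg.mergeValue(combiners[k], v)
            else:
                combiners[k] = agg.createCombiner(v)
        return iter(combiners.items())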
@@ -1755,21 +1753,28 @@ class RDD(object):
return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)
+ def _can_spill(self):
+ return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"
+
+ def _memory_limit(self):
+ return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+
# TODO: support variant with custom partitioner
def groupByKey(self, numPartitions=None):
"""
Group the values for each key in the RDD into a single sequence.
- Hash-partitions the resulting RDD with into numPartitions partitions.
+ Hash-partitions the resulting RDD with numPartitions partitions.
Note: If you are grouping in order to perform an aggregation (such as a
sum or average) over each key, using reduceByKey or aggregateByKey will
provide much better performance.
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
- >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
+ >>> sorted(x.groupByKey().mapValues(len).collect())
+ [('a', 2), ('b', 1)]
+ >>> sorted(x.groupByKey().mapValues(list).collect())
[('a', [1, 1]), ('b', [1])]
"""
-
def createCombiner(x):
return [x]
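To make the docstring's note concrete with runnable calls (assuming an existing SparkContext sc), counting values per key with reduceByKey avoids shuffling every value:

    pairs = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
    # reduceByKey combines on the map side, shuffling one value per key.
    print(sorted(pairs.reduceByKey(lambda a, b: a + b).collect()))  # [('a', 2), ('b', 1)]
    # groupByKey ships every value across the network before counting.
    print(sorted(pairs.groupByKey().mapValues(len).collect()))      # [('a', 2), ('b', 1)]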
@@ -1781,8 +1786,27 @@ class RDD(object):
a.extend(b)
return a
- return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
- numPartitions).mapValues(lambda x: ResultIterable(x))
+ spill = self._can_spill()
+ memory = self._memory_limit()
+ serializer = self._jrdd_deserializer
+ agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
+
+ def combine(iterator):
+ merger = ExternalMerger(agg, memory * 0.9, serializer) \
+ if spill else InMemoryMerger(agg)
+ merger.mergeValues(iterator)
+ return merger.iteritems()
+
+ locally_combined = self.mapPartitions(combine, preservesPartitioning=True)
+ shuffled = locally_combined.partitionBy(numPartitions)
+
+ def groupByKey(it):
+ merger = ExternalGroupBy(agg, memory, serializer)\
+ if spill else InMemoryMerger(agg)
+ merger.mergeCombiners(it)
+ return merger.iteritems()
+
+ return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable)
def flatMapValues(self, f):
"""