Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/rdd.py | 91
1 file changed, 90 insertions(+), 1 deletion(-)
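In brief, the patch adds two methods to PySpark's RDD, treeReduce and treeAggregate, which merge partition results in a multi-level tree of a suggested depth instead of handing every partition's result to the driver in one step. A minimal usage sketch (not part of the patch), assuming a live SparkContext named `sc` just as the doctests below do:

    # Sketch only: `sc` is an existing SparkContext, as assumed by the doctests.
    from operator import add

    rdd = sc.parallelize(range(1, 101), 20)

    # reduce()/aggregate() combine all 20 partition results on the driver at once;
    # the tree variants added here merge them in intermediate rounds first.
    print(rdd.treeReduce(add))                 # 5050
    print(rdd.treeAggregate(0, add, add, 3))   # 5050, with a suggested depth of 3

The depth is only a suggestion: when there are few partitions, the extra rounds are skipped, as the loop in treeAggregate below shows.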
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index b6dd5a3bf0..2f8a0edfe9 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -29,7 +29,7 @@ import warnings
import heapq
import bisect
import random
-from math import sqrt, log, isinf, isnan
+from math import sqrt, log, isinf, isnan, pow, ceil
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
    BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
@@ -726,6 +726,43 @@ class RDD(object):
            return reduce(f, vals)
        raise ValueError("Can not reduce() empty RDD")
+    def treeReduce(self, f, depth=2):
+        """
+        Reduces the elements of this RDD in a multi-level tree pattern.
+
+        :param depth: suggested depth of the tree (default: 2)
+
+        >>> add = lambda x, y: x + y
+        >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+        >>> rdd.treeReduce(add)
+        -5
+        >>> rdd.treeReduce(add, 1)
+        -5
+        >>> rdd.treeReduce(add, 2)
+        -5
+        >>> rdd.treeReduce(add, 5)
+        -5
+        >>> rdd.treeReduce(add, 10)
+        -5
+        """
+        if depth < 1:
+            raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+        zeroValue = None, True  # Use the second entry to indicate whether this is a dummy value.
+
+        def op(x, y):
+            if x[1]:
+                return y
+            elif y[1]:
+                return x
+            else:
+                return f(x[0], y[0]), False
+
+        reduced = self.map(lambda x: (x, False)).treeAggregate(zeroValue, op, op, depth)
+        if reduced[1]:
+            raise ValueError("Cannot reduce empty RDD.")
+        return reduced[0]
+
    def fold(self, zeroValue, op):
        """
        Aggregate the elements of each partition, and then the results for all
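As an aside, the (value, isDummy) pairing in treeReduce above lets it delegate to treeAggregate without requiring a zero element for the user's function f: the dummy zero value is absorbed by op and never reaches f. A small standalone sketch (plain Python, illustrative names only, not part of the patch) of how op combines such pairs:

    f = lambda x, y: x + y  # stand-in for the user's reduce function

    def op(x, y):
        if x[1]:          # x is the dummy zero value -> keep y unchanged
            return y
        elif y[1]:        # y is the dummy zero value -> keep x unchanged
            return x
        else:             # both are real values -> combine with f
            return f(x[0], y[0]), False

    zero = (None, True)
    print(op(zero, (3, False)))        # (3, False): the dummy is absorbed
    print(op((3, False), (4, False)))  # (7, False): a real reduction
    print(op(zero, zero))              # (None, True): still a dummy, i.e. an empty RDD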
@@ -777,6 +814,58 @@ class RDD(object):
        return self.mapPartitions(func).fold(zeroValue, combOp)
+    def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
+        """
+        Aggregates the elements of this RDD in a multi-level tree
+        pattern.
+
+        :param depth: suggested depth of the tree (default: 2)
+
+        >>> add = lambda x, y: x + y
+        >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+        >>> rdd.treeAggregate(0, add, add)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 1)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 2)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 5)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 10)
+        -5
+        """
+        if depth < 1:
+            raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+        if self.getNumPartitions() == 0:
+            return zeroValue
+
+        def aggregatePartition(iterator):
+            acc = zeroValue
+            for obj in iterator:
+                acc = seqOp(acc, obj)
+            yield acc
+
+        partiallyAggregated = self.mapPartitions(aggregatePartition)
+        numPartitions = partiallyAggregated.getNumPartitions()
+        scale = max(int(ceil(pow(numPartitions, 1.0 / depth))), 2)
+        # If creating an extra level doesn't help reduce the wall-clock time, we stop the tree
+        # aggregation.
+        while numPartitions > scale + numPartitions / scale:
+            numPartitions /= scale
+            curNumPartitions = numPartitions
+
+            def mapPartition(i, iterator):
+                for obj in iterator:
+                    yield (i % curNumPartitions, obj)
+
+            partiallyAggregated = partiallyAggregated \
+                .mapPartitionsWithIndex(mapPartition) \
+                .reduceByKey(combOp, curNumPartitions) \
+                .values()
+
+        return partiallyAggregated.reduce(combOp)
+
    def max(self, key=None):
        """
        Find the maximum item in this RDD.
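For reference, the loop at the end of treeAggregate keeps shrinking the number of partitions by a scale factor of roughly the depth-th root of the starting count, and stops once an extra level would no longer help. A standalone sketch (no Spark required; the helper name is illustrative, and // stands in for the patch's Python 2 integer division) of the partition counts it steps through:

    from math import ceil

    def combine_schedule(num_partitions, depth=2):
        """Illustrative only: the intermediate partition counts used by the
        tree aggregation loop for a given starting count and depth."""
        scale = max(int(ceil(num_partitions ** (1.0 / depth))), 2)
        levels = []
        # Mirror of `while numPartitions > scale + numPartitions / scale`.
        while num_partitions > scale + num_partitions // scale:
            num_partitions //= scale
            levels.append(num_partitions)
        return levels

    print(combine_schedule(10))        # [2]: the doctests' 10 partitions collapse in one round
    print(combine_schedule(1000, 2))   # [31]: one intermediate level of ~32-way merges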