author     Xiangrui Meng <meng@databricks.com>  2015-01-28 17:26:03 -0800
committer  Xiangrui Meng <meng@databricks.com>  2015-01-28 17:26:03 -0800
commit     4ee79c71afc5175ba42b5e3d4088fe23db3e45d1 (patch)
tree       af05f349a568617cbd75a5db34c4ae6fd90a00de /python
parent     e80dc1c5a80cddba8b367cf5cdf9f71df5d87250 (diff)
download   spark-4ee79c71afc5175ba42b5e3d4088fe23db3e45d1.tar.gz
           spark-4ee79c71afc5175ba42b5e3d4088fe23db3e45d1.tar.bz2
           spark-4ee79c71afc5175ba42b5e3d4088fe23db3e45d1.zip
[SPARK-5430] move treeReduce and treeAggregate from mllib to core
We have seen many use cases of `treeAggregate`/`treeReduce` outside the ML domain. Maybe it is time to move them to Core.

pwendell

Author: Xiangrui Meng <meng@databricks.com>

Closes #4228 from mengxr/SPARK-5430 and squashes the following commits:

20ad40d [Xiangrui Meng] exclude tree* from mima
e89a43e [Xiangrui Meng] fix compile and update java doc
3ae1a4b [Xiangrui Meng] add treeReduce/treeAggregate to Python
6f948c5 [Xiangrui Meng] add treeReduce/treeAggregate to JavaRDDLike
d600b6c [Xiangrui Meng] move treeReduce and treeAggregate to core
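For orientation, a minimal usage sketch of the two methods this patch exposes on PySpark RDDs. It assumes a live SparkContext named `sc` (e.g. from a pyspark shell); the data, partition count, and depth values are illustrative and not taken from the patch:

    from operator import add

    rdd = sc.parallelize(range(1, 101), 20)

    # treeReduce: same result as reduce(), but partial results are merged
    # through intermediate levels rather than all at once on the driver.
    total = rdd.treeReduce(add, depth=3)  # 5050

    # treeAggregate: zero value, per-partition seqOp, cross-partition combOp,
    # again combined in a multi-level tree.
    sum_count = rdd.treeAggregate(
        (0, 0),
        lambda acc, x: (acc[0] + x, acc[1] + 1),  # seqOp: fold one element in
        lambda a, b: (a[0] + b[0], a[1] + b[1]),  # combOp: merge two accumulators
        depth=2,
    )
    # sum_count == (5050, 100)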
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/rdd.py  91
1 file changed, 90 insertions(+), 1 deletion(-)
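Before the diff itself, a minimal Spark-free sketch of the level-by-level partition collapse that the new treeAggregate performs; the function name and example numbers below are illustrative and not part of the patch:

    from math import ceil

    def collapse_schedule(num_partitions, depth=2):
        # Mirrors the scale / while-loop logic in the patched treeAggregate:
        # derive a branching factor from the suggested depth, then keep merging
        # groups of partitions while an extra level still reduces driver-side work.
        scale = max(int(ceil(num_partitions ** (1.0 / depth))), 2)
        levels = []
        while num_partitions > scale + num_partitions / scale:
            num_partitions //= scale
            levels.append(num_partitions)
        return scale, levels

    # 100 partitions at depth 2: branching factor 10, one intermediate level of
    # 10 partitions, then a final reduce of 10 values on the driver.
    print(collapse_schedule(100, depth=2))  # (10, [10])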
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index b6dd5a3bf0..2f8a0edfe9 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -29,7 +29,7 @@ import warnings
import heapq
import bisect
import random
-from math import sqrt, log, isinf, isnan
+from math import sqrt, log, isinf, isnan, pow, ceil
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
@@ -726,6 +726,43 @@ class RDD(object):
return reduce(f, vals)
raise ValueError("Can not reduce() empty RDD")
+ def treeReduce(self, f, depth=2):
+ """
+ Reduces the elements of this RDD in a multi-level tree pattern.
+
+ :param depth: suggested depth of the tree (default: 2)
+
+ >>> add = lambda x, y: x + y
+ >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+ >>> rdd.treeReduce(add)
+ -5
+ >>> rdd.treeReduce(add, 1)
+ -5
+ >>> rdd.treeReduce(add, 2)
+ -5
+ >>> rdd.treeReduce(add, 5)
+ -5
+ >>> rdd.treeReduce(add, 10)
+ -5
+ """
+ if depth < 1:
+ raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+ zeroValue = None, True # Use the second entry to indicate whether this is a dummy value.
+
+ def op(x, y):
+ if x[1]:
+ return y
+ elif y[1]:
+ return x
+ else:
+ return f(x[0], y[0]), False
+
+ reduced = self.map(lambda x: (x, False)).treeAggregate(zeroValue, op, op, depth)
+ if reduced[1]:
+ raise ValueError("Cannot reduce empty RDD.")
+ return reduced[0]
+
def fold(self, zeroValue, op):
"""
Aggregate the elements of each partition, and then the results for all
@@ -777,6 +814,58 @@ class RDD(object):
return self.mapPartitions(func).fold(zeroValue, combOp)
+ def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
+ """
+ Aggregates the elements of this RDD in a multi-level tree
+ pattern.
+
+ :param depth: suggested depth of the tree (default: 2)
+
+ >>> add = lambda x, y: x + y
+ >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+ >>> rdd.treeAggregate(0, add, add)
+ -5
+ >>> rdd.treeAggregate(0, add, add, 1)
+ -5
+ >>> rdd.treeAggregate(0, add, add, 2)
+ -5
+ >>> rdd.treeAggregate(0, add, add, 5)
+ -5
+ >>> rdd.treeAggregate(0, add, add, 10)
+ -5
+ """
+ if depth < 1:
+ raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+ if self.getNumPartitions() == 0:
+ return zeroValue
+
+ def aggregatePartition(iterator):
+ acc = zeroValue
+ for obj in iterator:
+ acc = seqOp(acc, obj)
+ yield acc
+
+ partiallyAggregated = self.mapPartitions(aggregatePartition)
+ numPartitions = partiallyAggregated.getNumPartitions()
+ scale = max(int(ceil(pow(numPartitions, 1.0 / depth))), 2)
+ # If creating an extra level doesn't help reduce the wall-clock time, we stop the tree
+ # aggregation.
+ while numPartitions > scale + numPartitions / scale:
+ numPartitions /= scale
+ curNumPartitions = numPartitions
+
+ def mapPartition(i, iterator):
+ for obj in iterator:
+ yield (i % curNumPartitions, obj)
+
+ partiallyAggregated = partiallyAggregated \
+ .mapPartitionsWithIndex(mapPartition) \
+ .reduceByKey(combOp, curNumPartitions) \
+ .values()
+
+ return partiallyAggregated.reduce(combOp)
+
def max(self, key=None):
"""
Find the maximum item in this RDD.