From 4ee79c71afc5175ba42b5e3d4088fe23db3e45d1 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Wed, 28 Jan 2015 17:26:03 -0800
Subject: [SPARK-5430] move treeReduce and treeAggregate from mllib to core

We have seen many use cases of `treeAggregate`/`treeReduce` outside the ML domain. Maybe it is time to move them to Core.

pwendell

Author: Xiangrui Meng

Closes #4228 from mengxr/SPARK-5430 and squashes the following commits:

20ad40d [Xiangrui Meng] exclude tree* from mima
e89a43e [Xiangrui Meng] fix compile and update java doc
3ae1a4b [Xiangrui Meng] add treeReduce/treeAggregate to Python
6f948c5 [Xiangrui Meng] add treeReduce/treeAggregate to JavaRDDLike
d600b6c [Xiangrui Meng] move treeReduce and treeAggregate to core
---
 python/pyspark/rdd.py | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 90 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index b6dd5a3bf0..2f8a0edfe9 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -29,7 +29,7 @@ import warnings
 import heapq
 import bisect
 import random
-from math import sqrt, log, isinf, isnan
+from math import sqrt, log, isinf, isnan, pow, ceil
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
@@ -726,6 +726,43 @@ class RDD(object):
             return reduce(f, vals)
         raise ValueError("Can not reduce() empty RDD")
 
+    def treeReduce(self, f, depth=2):
+        """
+        Reduces the elements of this RDD in a multi-level tree pattern.
+
+        :param depth: suggested depth of the tree (default: 2)
+
+        >>> add = lambda x, y: x + y
+        >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+        >>> rdd.treeReduce(add)
+        -5
+        >>> rdd.treeReduce(add, 1)
+        -5
+        >>> rdd.treeReduce(add, 2)
+        -5
+        >>> rdd.treeReduce(add, 5)
+        -5
+        >>> rdd.treeReduce(add, 10)
+        -5
+        """
+        if depth < 1:
+            raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+        zeroValue = None, True  # Use the second entry to indicate whether this is a dummy value.
+
+        def op(x, y):
+            if x[1]:
+                return y
+            elif y[1]:
+                return x
+            else:
+                return f(x[0], y[0]), False
+
+        reduced = self.map(lambda x: (x, False)).treeAggregate(zeroValue, op, op, depth)
+        if reduced[1]:
+            raise ValueError("Cannot reduce empty RDD.")
+        return reduced[0]
+
     def fold(self, zeroValue, op):
         """
         Aggregate the elements of each partition, and then the results for all
@@ -777,6 +814,58 @@
 
         return self.mapPartitions(func).fold(zeroValue, combOp)
 
+    def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
+        """
+        Aggregates the elements of this RDD in a multi-level tree
+        pattern.
+
+        :param depth: suggested depth of the tree (default: 2)
+
+        >>> add = lambda x, y: x + y
+        >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+        >>> rdd.treeAggregate(0, add, add)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 1)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 2)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 5)
+        -5
+        >>> rdd.treeAggregate(0, add, add, 10)
+        -5
+        """
+        if depth < 1:
+            raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+        if self.getNumPartitions() == 0:
+            return zeroValue
+
+        def aggregatePartition(iterator):
+            acc = zeroValue
+            for obj in iterator:
+                acc = seqOp(acc, obj)
+            yield acc
+
+        partiallyAggregated = self.mapPartitions(aggregatePartition)
+        numPartitions = partiallyAggregated.getNumPartitions()
+        scale = max(int(ceil(pow(numPartitions, 1.0 / depth))), 2)
+        # If creating an extra level doesn't help reduce the wall-clock time, we stop the tree
+        # aggregation.
+        while numPartitions > scale + numPartitions / scale:
+            numPartitions /= scale
+            curNumPartitions = numPartitions
+
+            def mapPartition(i, iterator):
+                for obj in iterator:
+                    yield (i % curNumPartitions, obj)
+
+            partiallyAggregated = partiallyAggregated \
+                .mapPartitionsWithIndex(mapPartition) \
+                .reduceByKey(combOp, curNumPartitions) \
+                .values()
+
+        return partiallyAggregated.reduce(combOp)
+
     def max(self, key=None):
         """
         Find the maximum item in this RDD.
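
Note on the level schedule in `treeAggregate`: `depth` only suggests the shape of the tree; the actual number of merge rounds follows from `scale`, the depth-th root of the partition count, and the loop stops once an extra level would not shrink the work enough. The following standalone sketch (no Spark required; the helper name `tree_aggregate_schedule` is invented for illustration) reproduces that arithmetic to show how many partial aggregates remain at each level before the final `reduce(combOp)` on the driver.

```python
# Standalone sketch of the level schedule used by treeAggregate in the patch
# above. Only the arithmetic is mirrored; `tree_aggregate_schedule` itself is
# a hypothetical helper, not part of the PySpark API.
from math import ceil


def tree_aggregate_schedule(num_partitions, depth=2):
    """Return the partition counts at each tree level before the final reduce."""
    # scale = depth-th root of the partition count, but never below 2.
    scale = max(int(ceil(num_partitions ** (1.0 / depth))), 2)
    levels = [num_partitions]
    # Stop adding levels once another round would not reduce the work enough;
    # integer division mirrors the Python 2 arithmetic in the patch.
    while num_partitions > scale + num_partitions // scale:
        num_partitions //= scale
        levels.append(num_partitions)
    return levels


if __name__ == "__main__":
    print(tree_aggregate_schedule(1000, depth=2))  # [1000, 31]
    print(tree_aggregate_schedule(1000, depth=3))  # [1000, 100, 10]
```

With 1000 partitions and the default depth of 2, the driver therefore only has to merge about 31 partial aggregates at the end instead of 1000.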
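For completeness, a small usage sketch against the new Python API. It assumes a local PySpark installation; the master string, app name, data, and partition count are arbitrary choices for the example, not part of the patch.

```python
# Usage sketch for the RDD.treeReduce / RDD.treeAggregate methods added by this
# patch. Master, app name, and the number of partitions are arbitrary.
from pyspark import SparkContext

sc = SparkContext("local[4]", "treeAggregateSketch")
rdd = sc.parallelize(list(range(1, 1001)), 100)  # 100 partitions

# treeReduce returns the same answer as reduce(), but merges partial results
# through intermediate stages instead of pulling every partition's result to
# the driver at once.
total = rdd.treeReduce(lambda x, y: x + y, depth=2)

# treeAggregate with a (sum, count) accumulator to compute the mean in one pass.
zero = (0, 0)
seq_op = lambda acc, x: (acc[0] + x, acc[1] + 1)
comb_op = lambda a, b: (a[0] + b[0], a[1] + b[1])
s, n = rdd.treeAggregate(zero, seq_op, comb_op, depth=2)

print(total)         # 500500
print(s / float(n))  # 500.5
sc.stop()
```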