author     Xiangrui Meng <meng@databricks.com>  2015-01-28 17:26:03 -0800
committer  Xiangrui Meng <meng@databricks.com>  2015-01-28 17:26:03 -0800
commit     4ee79c71afc5175ba42b5e3d4088fe23db3e45d1 (patch)
tree       af05f349a568617cbd75a5db34c4ae6fd90a00de /core
parent     e80dc1c5a80cddba8b367cf5cdf9f71df5d87250 (diff)
[SPARK-5430] move treeReduce and treeAggregate from mllib to core
We have seen many use cases of `treeAggregate`/`treeReduce` outside the ML domain. Maybe it is time to move them to Core. pwendell
Author: Xiangrui Meng <meng@databricks.com>
Closes #4228 from mengxr/SPARK-5430 and squashes the following commits:
20ad40d [Xiangrui Meng] exclude tree* from mima
e89a43e [Xiangrui Meng] fix compile and update java doc
3ae1a4b [Xiangrui Meng] add treeReduce/treeAggregate to Python
6f948c5 [Xiangrui Meng] add treeReduce/treeAggregate to JavaRDDLike
d600b6c [Xiangrui Meng] move treeReduce and treeAggregate to core
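
For orientation, here is a minimal usage sketch of the two methods as they land in core. The RDD contents, partition count, and app name are illustrative, not from the patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical driver setup; any SparkContext will do.
val sc = new SparkContext(new SparkConf().setAppName("tree-ops").setMaster("local[4]"))
val rdd = sc.parallelize(1 to 1000, numSlices = 100)

// treeReduce: same contract as reduce, but partial results are combined
// in a multi-level tree rather than all at once on the driver.
val sum = rdd.treeReduce(_ + _, depth = 2)

// treeAggregate: same contract as aggregate, with the same tree-shaped combine step.
val total = rdd.treeAggregate(0L)((c, x) => c + x, (c1, c2) => c1 + c2, 2)
```

With many partitions and heavy per-partition results (the MLlib gradient-aggregation case that motivated the move), the tree pattern keeps the driver from becoming the single bottleneck in the final combine.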
Diffstat (limited to 'core')
 core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala | 37 +++++
 core/src/main/scala/org/apache/spark/rdd/RDD.scala              | 63 ++++++++
 core/src/test/java/org/apache/spark/JavaAPISuite.java           | 30 ++++
 core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala         | 19 +++
4 files changed, 149 insertions(+), 0 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 62bf18d82d..0f91c942ec 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -349,6 +349,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
 
   /**
+   * Reduces the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.api.java.JavaRDDLike#reduce]]
+   */
+  def treeReduce(f: JFunction2[T, T, T], depth: Int): T = rdd.treeReduce(f, depth)
+
+  /**
+   * [[org.apache.spark.api.java.JavaRDDLike#treeReduce]] with suggested depth 2.
+   */
+  def treeReduce(f: JFunction2[T, T, T]): T = treeReduce(f, 2)
+
+  /**
    * Aggregate the elements of each partition, and then the results for all the partitions, using a
    * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
    * modify t1 and return it as its result value to avoid object allocation; however, it should not
@@ -370,6 +383,30 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U])
 
   /**
+   * Aggregates the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.api.java.JavaRDDLike#aggregate]]
+   */
+  def treeAggregate[U](
+      zeroValue: U,
+      seqOp: JFunction2[U, T, U],
+      combOp: JFunction2[U, U, U],
+      depth: Int): U = {
+    rdd.treeAggregate(zeroValue)(seqOp, combOp, depth)(fakeClassTag[U])
+  }
+
+  /**
+   * [[org.apache.spark.api.java.JavaRDDLike#treeAggregate]] with suggested depth 2.
+   */
+  def treeAggregate[U](
+      zeroValue: U,
+      seqOp: JFunction2[U, T, U],
+      combOp: JFunction2[U, U, U]): U = {
+    treeAggregate(zeroValue, seqOp, combOp, 2)
+  }
+
+  /**
    * Return the number of elements in the RDD.
    */
   def count(): Long = rdd.count()
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index ab7410a1f7..5f39384975 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -901,6 +901,38 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
+   * Reduces the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#reduce]]
+   */
+  def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    val cleanF = context.clean(f)
+    val reducePartition: Iterator[T] => Option[T] = iter => {
+      if (iter.hasNext) {
+        Some(iter.reduceLeft(cleanF))
+      } else {
+        None
+      }
+    }
+    val partiallyReduced = mapPartitions(it => Iterator(reducePartition(it)))
+    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
+      if (c.isDefined && x.isDefined) {
+        Some(cleanF(c.get, x.get))
+      } else if (c.isDefined) {
+        c
+      } else if (x.isDefined) {
+        x
+      } else {
+        None
+      }
+    }
+    partiallyReduced.treeAggregate(Option.empty[T])(op, op, depth)
+      .getOrElse(throw new UnsupportedOperationException("empty collection"))
+  }
+
+  /**
    * Aggregate the elements of each partition, and then the results for all the partitions, using a
    * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
    * modify t1 and return it as its result value to avoid object allocation; however, it should not
@@ -936,6 +968,37 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
+   * Aggregates the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   */
+  def treeAggregate[U: ClassTag](zeroValue: U)(
+      seqOp: (U, T) => U,
+      combOp: (U, U) => U,
+      depth: Int = 2): U = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    if (partitions.size == 0) {
+      return Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
+    }
+    val cleanSeqOp = context.clean(seqOp)
+    val cleanCombOp = context.clean(combOp)
+    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+    var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
+    var numPartitions = partiallyAggregated.partitions.size
+    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
+    while (numPartitions > scale + numPartitions / scale) {
+      numPartitions /= scale
+      val curNumPartitions = numPartitions
+      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
+        iter.map((i % curNumPartitions, _))
+      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
+    }
+    partiallyAggregated.reduce(cleanCombOp)
+  }
+
+  /**
    * Return the number of elements in the RDD.
    */
   def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 004de05c10..b16a1e9460 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -492,6 +492,36 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(33, sum);
   }
 
+  @Test
+  public void treeReduce() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10);
+    Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() {
+      @Override
+      public Integer call(Integer a, Integer b) {
+        return a + b;
+      }
+    };
+    for (int depth = 1; depth <= 10; depth++) {
+      int sum = rdd.treeReduce(add, depth);
+      Assert.assertEquals(-5, sum);
+    }
+  }
+
+  @Test
+  public void treeAggregate() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10);
+    Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() {
+      @Override
+      public Integer call(Integer a, Integer b) {
+        return a + b;
+      }
+    };
+    for (int depth = 1; depth <= 10; depth++) {
+      int sum = rdd.treeAggregate(0, add, add, depth);
+      Assert.assertEquals(-5, sum);
+    }
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void aggregateByKey() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index e33b4bbbb8..bede1ffb3e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -157,6 +157,24 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
   }
 
+  test("treeAggregate") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    def seqOp = (c: Long, x: Int) => c + x
+    def combOp = (c1: Long, c2: Long) => c1 + c2
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth)
+      assert(sum === -1000L)
+    }
+  }
+
+  test("treeReduce") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeReduce(_ + _, depth)
+      assert(sum === -1000)
+    }
+  }
+
   test("basic caching") {
     val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
     assert(rdd.collect().toList === List(1, 2, 3, 4))
@@ -967,4 +985,5 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assertFails { sc.parallelize(1 to 100) }
     assertFails { sc.textFile("/nonexistent-path") }
   }
+
 }
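
The level-collapsing loop in the new `RDD.treeAggregate` above is the heart of the change: each pass shrinks the number of partially aggregated partitions by a factor of roughly `numPartitions^(1/depth)`, and it stops once another shuffle level would no longer reduce work. A standalone sketch of just that arithmetic, with illustrative inputs (no Spark required):

```scala
// Traces the partition counts treeAggregate's while loop would produce.
// Pure arithmetic mirroring the patch; the inputs are illustrative.
def traceLevels(initialPartitions: Int, depth: Int): List[Int] = {
  val scale = math.max(math.ceil(math.pow(initialPartitions, 1.0 / depth)).toInt, 2)
  var numPartitions = initialPartitions
  val levels = scala.collection.mutable.ListBuffer(numPartitions)
  // Same stopping rule as the patch: add a level only while it shrinks the combine work.
  while (numPartitions > scale + numPartitions / scale) {
    numPartitions /= scale
    levels += numPartitions
  }
  levels.toList
}

// 1000 partitions at depth 3 give scale = 10:
// 1000 -> 100 -> 10, and the final reduce(cleanCombOp) handles the last 10.
println(traceLevels(1000, 3)) // List(1000, 100, 10)
```

Note also how `treeReduce` reuses this machinery: it wraps each partition's partial result in `Option`, so empty partitions contribute `None` instead of requiring a zero value, and it throws `UnsupportedOperationException` only when the entire RDD is empty.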