author    Xiangrui Meng <meng@databricks.com>  2015-01-28 17:26:03 -0800
committer Xiangrui Meng <meng@databricks.com>  2015-01-28 17:26:03 -0800
commit    4ee79c71afc5175ba42b5e3d4088fe23db3e45d1
tree      af05f349a568617cbd75a5db34c4ae6fd90a00de
parent    e80dc1c5a80cddba8b367cf5cdf9f71df5d87250
[SPARK-5430] move treeReduce and treeAggregate from mllib to core
We have seen many use cases of `treeAggregate`/`treeReduce` outside the ML domain. Maybe it is time to move them to Core. pwendell

Author: Xiangrui Meng <meng@databricks.com>

Closes #4228 from mengxr/SPARK-5430 and squashes the following commits:

20ad40d [Xiangrui Meng] exclude tree* from mima
e89a43e [Xiangrui Meng] fix compile and update java doc
3ae1a4b [Xiangrui Meng] add treeReduce/treeAggregate to Python
6f948c5 [Xiangrui Meng] add treeReduce/treeAggregate to JavaRDDLike
d600b6c [Xiangrui Meng] move treeReduce and treeAggregate to core
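For context, a minimal spark-shell sketch of the two methods as they now appear on RDD (the sample data, partition count, and chosen depths below are illustrative, not part of this commit):

// Assumes a running spark-shell, so `sc` is an existing SparkContext.
val nums = sc.parallelize(1 to 1000, 100)

// treeReduce: same contract as reduce, but partial results are combined
// in a multi-level tree instead of all at once on the driver.
val total = nums.treeReduce(_ + _, 3)

// treeAggregate: same contract as aggregate, plus a suggested tree depth
// (both methods default to depth = 2 when it is omitted).
val (sum, count) = nums.treeAggregate((0L, 0L))(
  (acc: (Long, Long), x: Int) => (acc._1 + x, acc._2 + 1L),
  (a: (Long, Long), b: (Long, Long)) => (a._1 + b._1, a._2 + b._2),
  depth = 2)
val mean = sum.toDouble / count

The Java overloads added to JavaRDDLike below take the same arguments in a single flat parameter list, with and without an explicit depth.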
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala | 37
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala | 63
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java | 30
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 19
4 files changed, 149 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 62bf18d82d..0f91c942ec 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -349,6 +349,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
/**
+ * Reduces the elements of this RDD in a multi-level tree pattern.
+ *
+ * @param depth suggested depth of the tree
+ * @see [[org.apache.spark.api.java.JavaRDDLike#reduce]]
+ */
+ def treeReduce(f: JFunction2[T, T, T], depth: Int): T = rdd.treeReduce(f, depth)
+
+ /**
+ * [[org.apache.spark.api.java.JavaRDDLike#treeReduce]] with suggested depth 2.
+ */
+ def treeReduce(f: JFunction2[T, T, T]): T = treeReduce(f, 2)
+
+ /**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
* given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
* modify t1 and return it as its result value to avoid object allocation; however, it should not
@@ -370,6 +383,30 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U])
/**
+ * Aggregates the elements of this RDD in a multi-level tree pattern.
+ *
+ * @param depth suggested depth of the tree
+ * @see [[org.apache.spark.api.java.JavaRDDLike#aggregate]]
+ */
+ def treeAggregate[U](
+ zeroValue: U,
+ seqOp: JFunction2[U, T, U],
+ combOp: JFunction2[U, U, U],
+ depth: Int): U = {
+ rdd.treeAggregate(zeroValue)(seqOp, combOp, depth)(fakeClassTag[U])
+ }
+
+ /**
+ * [[org.apache.spark.api.java.JavaRDDLike#treeAggregate]] with suggested depth 2.
+ */
+ def treeAggregate[U](
+ zeroValue: U,
+ seqOp: JFunction2[U, T, U],
+ combOp: JFunction2[U, U, U]): U = {
+ treeAggregate(zeroValue, seqOp, combOp, 2)
+ }
+
+ /**
* Return the number of elements in the RDD.
*/
def count(): Long = rdd.count()
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index ab7410a1f7..5f39384975 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -901,6 +901,38 @@ abstract class RDD[T: ClassTag](
}
/**
+ * Reduces the elements of this RDD in a multi-level tree pattern.
+ *
+ * @param depth suggested depth of the tree (default: 2)
+ * @see [[org.apache.spark.rdd.RDD#reduce]]
+ */
+ def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
+ require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+ val cleanF = context.clean(f)
+ val reducePartition: Iterator[T] => Option[T] = iter => {
+ if (iter.hasNext) {
+ Some(iter.reduceLeft(cleanF))
+ } else {
+ None
+ }
+ }
+ val partiallyReduced = mapPartitions(it => Iterator(reducePartition(it)))
+ val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
+ if (c.isDefined && x.isDefined) {
+ Some(cleanF(c.get, x.get))
+ } else if (c.isDefined) {
+ c
+ } else if (x.isDefined) {
+ x
+ } else {
+ None
+ }
+ }
+ partiallyReduced.treeAggregate(Option.empty[T])(op, op, depth)
+ .getOrElse(throw new UnsupportedOperationException("empty collection"))
+ }
+
+ /**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
* given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
* modify t1 and return it as its result value to avoid object allocation; however, it should not
@@ -936,6 +968,37 @@ abstract class RDD[T: ClassTag](
}
/**
+ * Aggregates the elements of this RDD in a multi-level tree pattern.
+ *
+ * @param depth suggested depth of the tree (default: 2)
+ * @see [[org.apache.spark.rdd.RDD#aggregate]]
+ */
+ def treeAggregate[U: ClassTag](zeroValue: U)(
+ seqOp: (U, T) => U,
+ combOp: (U, U) => U,
+ depth: Int = 2): U = {
+ require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+ if (partitions.size == 0) {
+ return Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
+ }
+ val cleanSeqOp = context.clean(seqOp)
+ val cleanCombOp = context.clean(combOp)
+ val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+ var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
+ var numPartitions = partiallyAggregated.partitions.size
+ val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+ // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
+ while (numPartitions > scale + numPartitions / scale) {
+ numPartitions /= scale
+ val curNumPartitions = numPartitions
+ partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
+ iter.map((i % curNumPartitions, _))
+ }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
+ }
+ partiallyAggregated.reduce(cleanCombOp)
+ }
+
+ /**
* Return the number of elements in the RDD.
*/
def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum
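To make the stopping rule in the loop above concrete, here is a small standalone sketch (the object and method names and the sample inputs are illustrative) that mirrors the scale/while arithmetic from treeAggregate and reports the partition count at each intermediate level:

object TreeAggregatePlan {
  // Mirrors the level-collapsing loop in RDD.treeAggregate: pick a
  // branching factor `scale`, then keep shrinking the partition count
  // while one more level still reduces the work left for the driver.
  def planLevels(partitions: Int, depth: Int): List[Int] = {
    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
    val scale = math.max(math.ceil(math.pow(partitions, 1.0 / depth)).toInt, 2)
    var numPartitions = partitions
    val levels = scala.collection.mutable.ListBuffer(numPartitions)
    while (numPartitions > scale + numPartitions / scale) {
      numPartitions /= scale
      levels += numPartitions
    }
    levels.toList
  }

  def main(args: Array[String]): Unit = {
    // 1000 partitions at the default depth of 2: scale = 32, so the plan
    // is 1000 -> 31, followed by a final reduce of 31 partial results.
    println(planLevels(1000, 2).mkString(" -> "))
    // depth = 1 creates no intermediate level; all 1000 per-partition
    // results are combined directly, much as plain aggregate would do.
    println(planLevels(1000, 1).mkString(" -> "))
  }
}

So with the default depth of 2, an RDD with 1000 partitions gets one intermediate combine level of 31 partitions before partiallyAggregated.reduce(cleanCombOp) merges the remaining partial results.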
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 004de05c10..b16a1e9460 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -492,6 +492,36 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals(33, sum);
}
+ @Test
+ public void treeReduce() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10);
+ Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer a, Integer b) {
+ return a + b;
+ }
+ };
+ for (int depth = 1; depth <= 10; depth++) {
+ int sum = rdd.treeReduce(add, depth);
+ Assert.assertEquals(-5, sum);
+ }
+ }
+
+ @Test
+ public void treeAggregate() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10);
+ Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer a, Integer b) {
+ return a + b;
+ }
+ };
+ for (int depth = 1; depth <= 10; depth++) {
+ int sum = rdd.treeAggregate(0, add, add, depth);
+ Assert.assertEquals(-5, sum);
+ }
+ }
+
@SuppressWarnings("unchecked")
@Test
public void aggregateByKey() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index e33b4bbbb8..bede1ffb3e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -157,6 +157,24 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
}
+ test("treeAggregate") {
+ val rdd = sc.makeRDD(-1000 until 1000, 10)
+ def seqOp = (c: Long, x: Int) => c + x
+ def combOp = (c1: Long, c2: Long) => c1 + c2
+ for (depth <- 1 until 10) {
+ val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth)
+ assert(sum === -1000L)
+ }
+ }
+
+ test("treeReduce") {
+ val rdd = sc.makeRDD(-1000 until 1000, 10)
+ for (depth <- 1 until 10) {
+ val sum = rdd.treeReduce(_ + _, depth)
+ assert(sum === -1000)
+ }
+ }
+
test("basic caching") {
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
assert(rdd.collect().toList === List(1, 2, 3, 4))
@@ -967,4 +985,5 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assertFails { sc.parallelize(1 to 100) }
assertFails { sc.textFile("/nonexistent-path") }
}
+
}