From 563c5e717cc75869c328bba17116313eab9e976b Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Sat, 23 Apr 2011 13:38:37 -0700 Subject: Refactor and add aggregator support Refactored out the agg() and comp() methods from Pregel.run. Defined an implicit conversion to allow applications that don't use aggregators to avoid including a null argument for the result of the aggregator in the compute function. --- bagel/src/main/scala/bagel/Pregel.scala | 111 ++++++++++++++------- bagel/src/main/scala/bagel/ShortestPath.scala | 5 +- bagel/src/main/scala/bagel/WikipediaPageRank.scala | 6 +- bagel/src/test/scala/bagel/BagelSuite.scala | 10 +- 4 files changed, 88 insertions(+), 44 deletions(-) (limited to 'bagel/src') diff --git a/bagel/src/main/scala/bagel/Pregel.scala b/bagel/src/main/scala/bagel/Pregel.scala index e3b6d0c70a..2ec94ebb0d 100644 --- a/bagel/src/main/scala/bagel/Pregel.scala +++ b/bagel/src/main/scala/bagel/Pregel.scala @@ -6,37 +6,62 @@ import spark.SparkContext._ import scala.collection.mutable.ArrayBuffer object Pregel extends Logging { - /** - * Runs a Pregel job on the given vertices consisting of the - * specified compute function. - * - * Before beginning the first superstep, the given messages are sent - * to their destination vertices. - * - * During the job, the specified combiner functions are applied to - * messages as they travel between vertices. - * - * The job halts and returns the resulting set of vertices when no - * messages are being sent between vertices and all vertices have - * voted to halt by setting their state to inactive. - */ - def run[V <: Vertex : Manifest, M <: Message : Manifest, C]( + def run[V <: Vertex : Manifest, M <: Message : Manifest, C : Manifest, A : Manifest]( sc: SparkContext, verts: RDD[(String, V)], - msgs: RDD[(String, M)], - combiner: Combiner[M, C], - numSplits: Int, - superstep: Int = 0 - )(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = { + msgs: RDD[(String, M)] + )( + combiner: Combiner[M, C] = new DefaultCombiner[M], + aggregator: Aggregator[V, A] = new NullAggregator[V], + superstep: Int = 0, + numSplits: Int = sc.numCores + )( + compute: (V, Option[C], A, Int) => (V, Iterable[M]) + ): RDD[V] = { logInfo("Starting superstep "+superstep+".") val startTime = System.currentTimeMillis - // Bring together vertices and messages + val aggregated = agg(verts, aggregator) val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits) val grouped = verts.groupWith(combinedMsgs) + val (processed, numMsgs, numActiveVerts) = comp[V, M, C](sc, grouped, compute(_, _, aggregated, superstep)) + + val timeTaken = System.currentTimeMillis - startTime + logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) + + // Check stopping condition and iterate + val noActivity = numMsgs == 0 && numActiveVerts == 0 + if (noActivity) { + processed.map { case (id, (vert, msgs)) => vert } + } else { + val newVerts = processed.mapValues { case (vert, msgs) => vert } + val newMsgs = processed.flatMap { + case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) + } + run(sc, newVerts, newMsgs)(combiner, aggregator, superstep + 1, numSplits)(compute) + } + } + + /** + * Aggregates the given vertices using the given aggregator, or does + * nothing if it is a NullAggregator. + */ + def agg[V <: Vertex, A : Manifest](verts: RDD[(String, V)], aggregator: Aggregator[V, A]): A = aggregator match { + case _: NullAggregator[_] => + None + case _ => + verts.map { + case (id, vert) => aggregator.createAggregator(vert) + }.reduce(aggregator.mergeAggregators(_, _)) + } - // Run compute on each vertex + /** + * Processes the given vertex-message RDD using the compute + * function. Returns the processed RDD, the number of messages + * created, and the number of active vertices. + */ + def comp[V <: Vertex, M <: Message, C](sc: SparkContext, grouped: RDD[(String, (Seq[V], Seq[C]))], compute: (V, Option[C]) => (V, Iterable[M])): (RDD[(String, (V, Iterable[M]))], Int, Int) = { var numMsgs = sc.accumulator(0) var numActiveVerts = sc.accumulator(0) val processed = grouped.flatMapValues { @@ -46,7 +71,7 @@ object Pregel extends Logging { compute(v, c match { case Seq(comb) => Some(comb) case Seq() => None - }, superstep) + }) numMsgs += newMsgs.size if (newVert.active) @@ -58,30 +83,36 @@ object Pregel extends Logging { // Force evaluation of processed RDD for accurate performance measurements processed.foreach(x => {}) - val timeTaken = System.currentTimeMillis - startTime - logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) + (processed, numMsgs.value, numActiveVerts.value) + } - // Check stopping condition and iterate - val noActivity = numMsgs.value == 0 && numActiveVerts.value == 0 - if (noActivity) { - processed.map { case (id, (vert, msgs)) => vert } - } else { - val newVerts = processed.mapValues { case (vert, msgs) => vert } - val newMsgs = processed.flatMap { - case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) - } - run(sc, newVerts, newMsgs, combiner, numSplits, superstep + 1)(compute) - } + /** + * Converts a compute function that doesn't take an aggregator to + * one that does, so it can be passed to Pregel.run. + */ + implicit def addAggregatorArg[ + V <: Vertex : Manifest, M <: Message : Manifest, C + ]( + compute: (V, Option[C], Int) => (V, Iterable[M]) + ): (V, Option[C], Option[Nothing], Int) => (V, Iterable[M]) = { + (vert: V, messages: Option[C], aggregator: Option[Nothing], superstep: Int) => compute(vert, messages, superstep) } } +// TODO: Simplify Combiner interface and make it more OO. trait Combiner[M, C] { def createCombiner(msg: M): C def mergeMsg(combiner: C, msg: M): C def mergeCombiners(a: C, b: C): C } -@serializable class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] { +trait Aggregator[V, A] { + def createAggregator(vert: V): A + def mergeAggregators(a: A, b: A): A +} + +@serializable +class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] { def createCombiner(msg: M): ArrayBuffer[M] = ArrayBuffer(msg) def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] = @@ -90,6 +121,12 @@ trait Combiner[M, C] { a ++= b } +@serializable +class NullAggregator[V] extends Aggregator[V, Option[Nothing]] { + def createAggregator(vert: V): Option[Nothing] = None + def mergeAggregators(a: Option[Nothing], b: Option[Nothing]): Option[Nothing] = None +} + /** * Represents a Pregel vertex. * diff --git a/bagel/src/main/scala/bagel/ShortestPath.scala b/bagel/src/main/scala/bagel/ShortestPath.scala index 3fd2f39334..8f4a881850 100644 --- a/bagel/src/main/scala/bagel/ShortestPath.scala +++ b/bagel/src/main/scala/bagel/ShortestPath.scala @@ -5,6 +5,8 @@ import spark.SparkContext._ import scala.math.min +import bagel.Pregel._ + object ShortestPath { def main(args: Array[String]) { if (args.length < 4) { @@ -49,7 +51,7 @@ object ShortestPath { messages.count()+" messages.") // Do the computation - val result = Pregel.run(sc, vertices, messages, MinCombiner, numSplits) { + val compute = addAggregatorArg { (self: SPVertex, messageMinValue: Option[Int], superstep: Int) => val newValue = messageMinValue match { case Some(minVal) => min(self.value, minVal) @@ -65,6 +67,7 @@ object ShortestPath { (new SPVertex(self.id, newValue, self.outEdges, false), outbox) } + val result = Pregel.run(sc, vertices, messages)(combiner = MinCombiner, numSplits = numSplits)(compute) // Print the result System.err.println("Shortest path from "+startVertex+" to all vertices:") diff --git a/bagel/src/main/scala/bagel/WikipediaPageRank.scala b/bagel/src/main/scala/bagel/WikipediaPageRank.scala index 994cea8ec3..2fe77b4962 100644 --- a/bagel/src/main/scala/bagel/WikipediaPageRank.scala +++ b/bagel/src/main/scala/bagel/WikipediaPageRank.scala @@ -3,6 +3,8 @@ package bagel import spark._ import spark.SparkContext._ +import bagel.Pregel._ + import scala.collection.mutable.ArrayBuffer import scala.xml.{XML,NodeSeq} @@ -60,9 +62,9 @@ object WikipediaPageRank { val messages = sc.parallelize(List[(String, PRMessage)]()) val result = if (noCombiner) { - Pregel.run(sc, vertices, messages, PRNoCombiner, numSplits)(PRNoCombiner.compute(numVertices, epsilon)) + Pregel.run(sc, vertices, messages)(numSplits = numSplits)(PRNoCombiner.compute(numVertices, epsilon)) } else { - Pregel.run(sc, vertices, messages, PRCombiner, numSplits)(PRCombiner.compute(numVertices, epsilon)) + Pregel.run(sc, vertices, messages)(combiner = PRCombiner, numSplits = numSplits)(PRCombiner.compute(numVertices, epsilon)) } // Print the result diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala index 29f5f0c358..53a93a6b80 100644 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/bagel/BagelSuite.scala @@ -10,6 +10,8 @@ import scala.collection.mutable.ArrayBuffer import spark._ +import bagel.Pregel._ + @serializable class TestVertex(val id: String, val active: Boolean, val age: Int) extends Vertex @serializable class TestMessage(val targetId: String) extends Message @@ -20,10 +22,10 @@ class BagelSuite extends FunSuite with Assertions { val msgs = sc.parallelize(Array[(String, TestMessage)]()) val numSupersteps = 5 val result = - Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) { + Pregel.run(sc, verts, msgs)()(addAggregatorArg { (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) => (new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } + }) for (vert <- result.collect) assert(vert.age === numSupersteps) } @@ -34,7 +36,7 @@ class BagelSuite extends FunSuite with Assertions { val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) val numSupersteps = 5 val result = - Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) { + Pregel.run(sc, verts, msgs)()(addAggregatorArg { (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) => val msgsOut = msgs match { @@ -44,7 +46,7 @@ class BagelSuite extends FunSuite with Assertions { new ArrayBuffer[TestMessage]() } (new TestVertex(self.id, self.active, self.age + 1), msgsOut) - } + }) for (vert <- result.collect) assert(vert.age === numSupersteps) } -- cgit v1.2.3