From c5b3ea755ff8a69aa39dd6e46d57cbe9d5bcbcae Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 13 Apr 2011 18:40:44 -0700 Subject: Clean up Bagel source and interface --- bagel/src/main/scala/bagel/Pregel.scala | 109 +++++++++++---------- bagel/src/main/scala/bagel/ShortestPath.scala | 15 ++- bagel/src/main/scala/bagel/WikipediaPageRank.scala | 99 ++++++------------- 3 files changed, 99 insertions(+), 124 deletions(-) (limited to 'bagel') diff --git a/bagel/src/main/scala/bagel/Pregel.scala b/bagel/src/main/scala/bagel/Pregel.scala index b0f7a48b7a..5ef398d783 100644 --- a/bagel/src/main/scala/bagel/Pregel.scala +++ b/bagel/src/main/scala/bagel/Pregel.scala @@ -7,75 +7,81 @@ import scala.collection.mutable.ArrayBuffer object Pregel extends Logging { /** - * Runs a Pregel job on the given vertices, running the specified - * compute function on each vertex in every superstep. Before - * beginning the first superstep, sends the given messages to their - * destination vertices. In the join stage, launches splits - * separate tasks (where splits is manually specified to work - * around a bug in Spark). + * Runs a Pregel job on the given vertices consisting of the + * specified compute function. * - * Halts when no more messages are being sent between vertices, and - * all vertices have voted to halt by setting their state to - * Inactive. + * Before beginning the first superstep, the given messages are sent + * to their destination vertices. + * + * During the job, the specified combiner functions are applied to + * messages as they travel between vertices. + * + * The job halts and returns the resulting set of vertices when no + * messages are being sent between vertices and all vertices have + * voted to halt by setting their state to inactive. */ - def run[V <: Vertex : Manifest, M <: Message : Manifest, C](sc: SparkContext, verts: RDD[(String, V)], msgs: RDD[(String, M)], splits: Int, messageCombiner: (C, M) => C, defaultCombined: () => C, mergeCombined: (C, C) => C, maxSupersteps: Option[Int] = None, superstep: Int = 0)(compute: (V, C, Int) => (V, Iterable[M])): RDD[V] = { + def run[V <: Vertex : Manifest, M <: Message : Manifest, C]( + sc: SparkContext, + verts: RDD[(String, V)], + msgs: RDD[(String, M)], + createCombiner: M => C, + mergeMsg: (C, M) => C, + mergeCombiners: (C, C) => C, + numSplits: Int, + superstep: Int = 0 + )(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = { + logInfo("Starting superstep "+superstep+".") val startTime = System.currentTimeMillis // Bring together vertices and messages - val combinedMsgs = msgs.combineByKey({x => messageCombiner(defaultCombined(), x)}, messageCombiner, mergeCombined, splits) - logDebug("verts.splits.size = " + verts.splits.size) - logDebug("combinedMsgs.splits.size = " + combinedMsgs.splits.size) - logDebug("verts.partitioner = " + verts.partitioner) - logDebug("combinedMsgs.partitioner = " + combinedMsgs.partitioner) - - val joined = verts.groupWith(combinedMsgs) - logDebug("joined.splits.size = " + joined.splits.size) - logDebug("joined.partitioner = " + joined.partitioner) + val combinedMsgs = msgs.combineByKey(createCombiner, mergeMsg, mergeCombiners, numSplits) + val grouped = verts.groupWith(combinedMsgs) // Run compute on each vertex - var messageCount = sc.accumulator(0) - var activeVertexCount = sc.accumulator(0) - val processed = joined.flatMapValues { + var numMsgs = sc.accumulator(0) + var numActiveVerts = sc.accumulator(0) + val processed = grouped.flatMapValues { case (Seq(), _) => None - case (Seq(v), Seq(comb)) => - val (newVertex, newMessages) = compute(v, comb, superstep) + case (Seq(v), c) => + val (newVert, newMsgs) = + compute(v, c match { + case Seq(comb) => Some(comb) + case Seq() => None + }, superstep) - messageCount += newMessages.size - if (newVertex.active) - activeVertexCount += 1 + numMsgs += newMsgs.size + if (newVert.active) + numActiveVerts += 1 - Some((newVertex, newMessages)) - case (Seq(v), Seq()) => - val (newVertex, newMessages) = compute(v, defaultCombined(), superstep) - - messageCount += newMessages.size - if (newVertex.active) - activeVertexCount += 1 - - Some((newVertex, newMessages)) + Some((newVert, newMsgs)) }.cache + // Force evaluation of processed RDD for accurate performance measurements processed.foreach(x => {}) val timeTaken = System.currentTimeMillis - startTime logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) - // Check stopping condition and recurse - val stop = messageCount.value == 0 && activeVertexCount.value == 0 - if (stop || (maxSupersteps.isDefined && superstep >= maxSupersteps.get)) { - processed.map { _._2._1 } + // Check stopping condition and iterate + val noActivity = numMsgs.value == 0 && numActiveVerts.value == 0 + if (noActivity) { + processed.map { case (id, (vert, msgs)) => vert } } else { - val newVerts = processed.mapValues(_._1) - val newMsgs = processed.flatMap(x => x._2._2.map(m => (m.targetId, m))) - run(sc, newVerts, newMsgs, splits, messageCombiner, defaultCombined, mergeCombined, maxSupersteps, superstep + 1)(compute) + val newVerts = processed.mapValues { case (vert, msgs) => vert } + val newMsgs = processed.flatMap { + case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) + } + run(sc, newVerts, newMsgs, createCombiner, mergeMsg, mergeCombiners, numSplits, superstep + 1)(compute) } } } /** - * Represents a Pregel vertex. Must be subclassed to store state - * along with each vertex. Must be annotated with @serializable. + * Represents a Pregel vertex. + * + * Subclasses may store state along with each vertex and must be + * annotated with @serializable. */ trait Vertex { def id: String @@ -83,17 +89,20 @@ trait Vertex { } /** - * Represents a Pregel message to a target vertex. Must be - * subclassed to contain a payload. Must be annotated with @serializable. + * Represents a Pregel message to a target vertex. + * + * Subclasses may contain a payload to deliver to the target vertex + * and must be annotated with @serializable. */ trait Message { def targetId: String } /** - * Represents a directed edge between two vertices. Owned by the - * source vertex, and contains the ID of the target vertex. Must - * be subclassed to store state along with each edge. Must be annotated with @serializable. + * Represents a directed edge between two vertices. + * + * Subclasses may store state along each edge and must be annotated + * with @serializable. */ trait Edge { def targetId: String diff --git a/bagel/src/main/scala/bagel/ShortestPath.scala b/bagel/src/main/scala/bagel/ShortestPath.scala index 38f533728d..6699f58a31 100644 --- a/bagel/src/main/scala/bagel/ShortestPath.scala +++ b/bagel/src/main/scala/bagel/ShortestPath.scala @@ -49,12 +49,17 @@ object ShortestPath { messages.count()+" messages.") // Do the computation - def messageCombiner(minSoFar: Int, message: SPMessage): Int = - min(minSoFar, message.value) + def createCombiner(message: SPMessage): Int = message.value + def mergeMsg(combiner: Int, message: SPMessage): Int = + min(combiner, message.value) + def mergeCombiners(a: Int, b: Int): Int = min(a, b) - val result = Pregel.run(sc, vertices, messages, numSplits, messageCombiner, () => Int.MaxValue, min _) { - (self: SPVertex, messageMinValue: Int, superstep: Int) => - val newValue = min(self.value, messageMinValue) + val result = Pregel.run(sc, vertices, messages, createCombiner, mergeMsg, mergeCombiners, numSplits) { + (self: SPVertex, messageMinValue: Option[Int], superstep: Int) => + val newValue = messageMinValue match { + case Some(minVal) => min(self.value, minVal) + case None => self.value + } val outbox = if (newValue != self.value) diff --git a/bagel/src/main/scala/bagel/WikipediaPageRank.scala b/bagel/src/main/scala/bagel/WikipediaPageRank.scala index a98fd371e1..f6aeff0782 100644 --- a/bagel/src/main/scala/bagel/WikipediaPageRank.scala +++ b/bagel/src/main/scala/bagel/WikipediaPageRank.scala @@ -4,7 +4,6 @@ import spark._ import spark.SparkContext._ import scala.collection.mutable.ArrayBuffer - import scala.xml.{XML,NodeSeq} import java.io.{Externalizable,ObjectInput,ObjectOutput,DataOutputStream,DataInputStream} @@ -14,7 +13,7 @@ import com.esotericsoftware.kryo._ object WikipediaPageRank { def main(args: Array[String]) { if (args.length < 4) { - System.err.println("Usage: PageRank []") + System.err.println("Usage: WikipediaPageRank []") System.exit(-1) } @@ -52,22 +51,18 @@ object WikipediaPageRank { } val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*) val id = new String(title) - (id, (new PRVertex(id, 1.0 / numVertices, outEdges, true))) - }) - val graph = vertices.groupByKey(numSplits).mapValues(_.head).cache - + (id, new PRVertex(id, 1.0 / numVertices, outEdges, true)) + }).cache println("Done parsing input file.") - println("Input file had "+graph.count+" vertices.") // Do the computation val epsilon = 0.01 / numVertices + val messages = sc.parallelize(List[(String, PRMessage)]()) val result = if (noCombiner) { - val messages = sc.parallelize(List[(String, PRMessage)]()) - Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, graph, messages, numSplits, NoCombiner.messageCombiner, NoCombiner.defaultCombined, NoCombiner.mergeCombined)(NoCombiner.compute(numVertices, epsilon)) + Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, vertices, messages, NoCombiner.createCombiner, NoCombiner.mergeMsg, NoCombiner.mergeCombiners, numSplits)(NoCombiner.compute(numVertices, epsilon)) } else { - val messages = sc.parallelize(List[(String, PRMessage)]()) - Pregel.run[PRVertex, PRMessage, Double](sc, graph, messages, numSplits, Combiner.messageCombiner, Combiner.defaultCombined, Combiner.mergeCombined)(Combiner.compute(numVertices, epsilon)) + Pregel.run[PRVertex, PRMessage, Double](sc, vertices, messages, Combiner.createCombiner, Combiner.mergeMsg, Combiner.mergeCombiners, numSplits)(Combiner.compute(numVertices, epsilon)) } // Print the result @@ -78,19 +73,19 @@ object WikipediaPageRank { } object Combiner { - def messageCombiner(minSoFar: Double, message: PRMessage): Double = - minSoFar + message.value + def createCombiner(message: PRMessage): Double = message.value - def mergeCombined(a: Double, b: Double) = a + b + def mergeMsg(combiner: Double, message: PRMessage): Double = + combiner + message.value - def defaultCombined(): Double = 0.0 + def mergeCombiners(a: Double, b: Double) = a + b - def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Double, superstep: Int): (PRVertex, Iterable[PRMessage]) = { - val newValue = - if (messageSum != 0) - 0.15 / numVertices + 0.85 * messageSum - else - self.value + def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = { + val newValue = messageSum match { + case Some(msgSum) if msgSum != 0 => + 0.15 / numVertices + 0.85 * msgSum + case _ => self.value + } val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30 @@ -106,20 +101,24 @@ object WikipediaPageRank { } object NoCombiner { - def messageCombiner(messagesSoFar: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] = - messagesSoFar += message + def createCombiner(message: PRMessage): ArrayBuffer[PRMessage] = + ArrayBuffer(message) - def mergeCombined(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] = - a ++= b + def mergeMsg(combiner: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] = + combiner += message - def defaultCombined(): ArrayBuffer[PRMessage] = ArrayBuffer[PRMessage]() + def mergeCombiners(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] = + a ++= b - def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Seq[PRMessage], superstep: Int): (PRVertex, Iterable[PRMessage]) = - Combiner.compute(numVertices, epsilon)(self, messages.map(_.value).sum, superstep) + def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) = + Combiner.compute(numVertices, epsilon)(self, messages match { + case Some(msgs) => Some(msgs.map(_.value).sum) + case None => None + }, superstep) } } -@serializable class PRVertex() extends Vertex with Externalizable { +@serializable class PRVertex() extends Vertex { var id: String = _ var value: Double = _ var outEdges: ArrayBuffer[PREdge] = _ @@ -132,29 +131,9 @@ object WikipediaPageRank { this.outEdges = outEdges this.active = active } - - def writeExternal(out: ObjectOutput) { - out.writeUTF(id) - out.writeDouble(value) - out.writeInt(outEdges.length) - for (e <- outEdges) - out.writeUTF(e.targetId) - out.writeBoolean(active) - } - - def readExternal(in: ObjectInput) { - id = in.readUTF() - value = in.readDouble() - val numEdges = in.readInt() - outEdges = new ArrayBuffer[PREdge](numEdges) - for (i <- 0 until numEdges) { - outEdges += new PREdge(in.readUTF()) - } - active = in.readBoolean() - } } -@serializable class PRMessage() extends Message with Externalizable { +@serializable class PRMessage() extends Message { var targetId: String = _ var value: Double = _ @@ -163,33 +142,15 @@ object WikipediaPageRank { this.targetId = targetId this.value = value } - - def writeExternal(out: ObjectOutput) { - out.writeUTF(targetId) - out.writeDouble(value) - } - - def readExternal(in: ObjectInput) { - targetId = in.readUTF() - value = in.readDouble() - } } -@serializable class PREdge() extends Edge with Externalizable { +@serializable class PREdge() extends Edge { var targetId: String = _ def this(targetId: String) { this() this.targetId = targetId } - - def writeExternal(out: ObjectOutput) { - out.writeUTF(targetId) - } - - def readExternal(in: ObjectInput) { - targetId = in.readUTF() - } } class PRKryoRegistrator extends KryoRegistrator { -- cgit v1.2.3