From c0736f6f68e47b82e3634252f8dba4f709a33ba5 Mon Sep 17 00:00:00 2001
From: Ankur Dave
Date: Tue, 12 Apr 2011 17:57:04 -0700
Subject: Add Bagel, an implementation of Pregel on Spark

---
 bagel/src/main/scala/bagel/Pregel.scala            | 103 +++++++++++
 bagel/src/main/scala/bagel/ShortestPath.scala      |  86 +++++++++
 bagel/src/main/scala/bagel/WikipediaPageRank.scala | 201 +++++++++++++++++++++
 project/build/SparkProject.scala                   |   2 +
 4 files changed, 392 insertions(+)
 create mode 100644 bagel/src/main/scala/bagel/Pregel.scala
 create mode 100644 bagel/src/main/scala/bagel/ShortestPath.scala
 create mode 100644 bagel/src/main/scala/bagel/WikipediaPageRank.scala

diff --git a/bagel/src/main/scala/bagel/Pregel.scala b/bagel/src/main/scala/bagel/Pregel.scala
new file mode 100644
index 0000000000..02eec40b2c
--- /dev/null
+++ b/bagel/src/main/scala/bagel/Pregel.scala
@@ -0,0 +1,103 @@
+package bagel
+
+import spark._
+import spark.SparkContext._
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.ArrayBuffer
+
+object Pregel extends Logging {
+  /**
+   * Runs a Pregel job on the given vertices, running the specified
+   * compute function on each vertex in every superstep. Before
+   * beginning the first superstep, sends the given messages to their
+   * destination vertices. In the join stage, launches `splits`
+   * separate tasks (`splits` is specified manually to work around a
+   * bug in Spark).
+   *
+   * Halts when no more messages are being sent between vertices, and
+   * all vertices have voted to halt by setting their state to
+   * inactive.
+   */
+  def run[V <: Vertex : Manifest, M <: Message : Manifest, C](sc: SparkContext, verts: RDD[(String, V)], msgs: RDD[(String, M)], splits: Int, messageCombiner: (C, M) => C, defaultCombined: () => C, mergeCombined: (C, C) => C, superstep: Int = 0)(compute: (V, C, Int) => (V, Iterable[M])): RDD[V] = {
+    println("Starting superstep "+superstep+".")
+    val startTime = System.currentTimeMillis
+
+    // Bring together vertices and messages
+    println("Joining vertices and messages...")
+    val combinedMsgs = msgs.combineByKey({x => messageCombiner(defaultCombined(), x)}, messageCombiner, mergeCombined, splits)
+    println("verts.splits.size = " + verts.splits.size)
+    println("combinedMsgs.splits.size = " + combinedMsgs.splits.size)
+    println("verts.partitioner = " + verts.partitioner)
+    println("combinedMsgs.partitioner = " + combinedMsgs.partitioner)
+    val joined = verts.groupWith(combinedMsgs)
+    println("joined.splits.size = " + joined.splits.size)
+    println("joined.partitioner = " + joined.partitioner)
+    //val joined = graph.groupByKeyAsymmetrical(messageCombiner, defaultCombined, mergeCombined, splits)
+    println("Done joining vertices and messages.")
+
+    // Run compute on each vertex
+    println("Running compute on each vertex...")
+    var messageCount = sc.accumulator(0)
+    var activeVertexCount = sc.accumulator(0)
+    val processed = joined.flatMapValues {
+      case (Seq(), _) => None
+      case (Seq(v), Seq(comb)) =>
+        val (newVertex, newMessages) = compute(v, comb, superstep)
+        messageCount += newMessages.size
+        if (newVertex.active)
+          activeVertexCount += 1
+        Some((newVertex, newMessages))
+        //val result = ArrayBuffer[(String, Either[V, M])]((newVertex.id, Left(newVertex)))
+        //result ++= newMessages.map(m => (m.targetId, Right(m)))
+      case (Seq(v), Seq()) =>
+        val (newVertex, newMessages) = compute(v, defaultCombined(), superstep)
+        messageCount += newMessages.size
+        if (newVertex.active)
+          activeVertexCount += 1
+        Some((newVertex, newMessages))
+    }.cache
+    // MATEI: Added this. Forces evaluation of the cached RDD so the
+    // accumulators are populated before they are read below.
+    processed.foreach(x => {})
+    println("Done running compute on each vertex.")
+
+    println("Checking stopping condition...")
+    val stop = messageCount.value == 0 && activeVertexCount.value == 0
+
+    val timeTaken = System.currentTimeMillis - startTime
+    println("Superstep %d took %d s".format(superstep, timeTaken / 1000))
+
+    val newVerts = processed.mapValues(_._1)
+    val newMsgs = processed.flatMap(x => x._2._2.map(m => (m.targetId, m)))
+
+    if (superstep >= 10)
+      processed.map { _._2._1 }
+    else
+      run(sc, newVerts, newMsgs, splits, messageCombiner, defaultCombined, mergeCombined, superstep + 1)(compute)
+  }
+}
+
+/**
+ * Represents a Pregel vertex. Must be subclassed to store state
+ * along with each vertex. Must be annotated with @serializable.
+ */
+trait Vertex {
+  def id: String
+  def active: Boolean
+}
+
+/**
+ * Represents a Pregel message to a target vertex. Must be
+ * subclassed to contain a payload. Must be annotated with
+ * @serializable.
+ */
+trait Message {
+  def targetId: String
+}
+
+/**
+ * Represents a directed edge between two vertices. Owned by the
+ * source vertex, and contains the ID of the target vertex. Must be
+ * subclassed to store state along with each edge. Must be annotated
+ * with @serializable.
+ */
+trait Edge {
+  def targetId: String
+}
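To make the new API concrete, here is a minimal driver sketch against Pregel.run as
defined above. CounterVertex, CounterMessage, and the "local" master string are
hypothetical choices for illustration, not part of this patch; only the Vertex and
Message traits and the run signature are assumed. Note that run as written never
consults its stop flag and recurses until superstep >= 10, so even this one-message
job executes the full cap of supersteps.

    package bagel

    import spark._
    import spark.SparkContext._

    // Hypothetical vertex/message types implementing the traits from this patch.
    @serializable class CounterVertex(val id: String, val value: Int, val active: Boolean) extends Vertex
    @serializable class CounterMessage(val targetId: String, val value: Int) extends Message

    object CounterExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "CounterExample")
        val verts = sc.parallelize(List(("a", new CounterVertex("a", 0, true))))
        val msgs = sc.parallelize(List(("a", new CounterMessage("a", 1))))

        // Combiner type C = Int: fold each message's payload into a running sum.
        val result = Pregel.run[CounterVertex, CounterMessage, Int](
          sc, verts, msgs, 2,
          (sum: Int, m: CounterMessage) => sum + m.value, // messageCombiner
          () => 0,                                        // defaultCombined
          (a: Int, b: Int) => a + b                       // mergeCombined
        ) {
          (self: CounterVertex, msgSum: Int, superstep: Int) =>
            // Add the combined inbox to the vertex value, send no messages,
            // and vote to halt.
            (new CounterVertex(self.id, self.value + msgSum, false), Nil)
        }
        result.foreach(v => println(v.id + " = " + v.value))
      }
    }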
diff --git a/bagel/src/main/scala/bagel/ShortestPath.scala b/bagel/src/main/scala/bagel/ShortestPath.scala
new file mode 100644
index 0000000000..2af4dc5867
--- /dev/null
+++ b/bagel/src/main/scala/bagel/ShortestPath.scala
@@ -0,0 +1,86 @@
+package bagel
+
+import spark._
+import spark.SparkContext._
+
+import scala.math.min
+
+/*
+object ShortestPath {
+  def main(args: Array[String]) {
+    if (args.length < 4) {
+      System.err.println("Usage: ShortestPath <graphFile> <startVertex> " +
+                         "<numSplits> <host>")
+      System.exit(-1)
+    }
+
+    val graphFile = args(0)
+    val startVertex = args(1)
+    val numSplits = args(2).toInt
+    val host = args(3)
+    val sc = new SparkContext(host, "ShortestPath")
+
+    // Parse the graph data from a file into two RDDs, vertices and messages
+    val lines =
+      (sc.textFile(graphFile)
+       .filter(!_.matches("^\\s*#.*"))
+       .map(line => line.split("\t")))
+
+    val vertices: RDD[(String, Either[SPVertex, SPMessage])] =
+      (lines.groupBy(line => line(0))
+       .map {
+         case (vertexId, lines) => {
+           val outEdges = lines.collect {
+             case Array(_, targetId, edgeValue) =>
+               new SPEdge(targetId, edgeValue.toInt)
+           }
+
+           (vertexId, Left[SPVertex, SPMessage](new SPVertex(vertexId, Int.MaxValue, outEdges, true)))
+         }
+       })
+
+    val messages: RDD[(String, Either[SPVertex, SPMessage])] =
+      (lines.filter(_.length == 2)
+       .map {
+         case Array(vertexId, messageValue) =>
+           (vertexId, Right[SPVertex, SPMessage](new SPMessage(vertexId, messageValue.toInt)))
+       })
+
+    val graph: RDD[(String, Either[SPVertex, SPMessage])] = vertices ++ messages
+
+    System.err.println("Read "+vertices.count()+" vertices and "+
+                       messages.count()+" messages.")
+
+    // Do the computation
+    def messageCombiner(minSoFar: Int, message: SPMessage): Int =
+      min(minSoFar, message.value)
+
+    val result = Pregel.run(sc, graph, numSplits, messageCombiner, () => Int.MaxValue, min _) {
+      (self: SPVertex, messageMinValue: Int, superstep: Int) =>
+        val newValue = min(self.value, messageMinValue)
+
+        val outbox =
+          if (newValue != self.value)
+            self.outEdges.map(edge =>
+              new SPMessage(edge.targetId, newValue + edge.value))
+          else
+            List()
+
+        (new SPVertex(self.id, newValue, self.outEdges, false), outbox)
+    }
+
+    // Print the result
+    System.err.println("Shortest path from "+startVertex+" to all vertices:")
+    val shortest = result.map(vertex =>
+      "%s\t%s\n".format(vertex.id, vertex.value match {
+        case x if x == Int.MaxValue => "inf"
+        case x => x
+      })).collect.mkString
+    println(shortest)
+  }
+}
+
+@serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex
+@serializable class SPEdge(val targetId: String, val value: Int) extends Edge
+@serializable class SPMessage(val targetId: String, val value: Int) extends Message
+*/
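The commented-out job above expects a tab-separated graph file: three-field lines
define edges (source, target, length), while two-field lines inject initial
messages (vertex, distance), which is how the start vertex would receive distance
0. A hypothetical input in that format (columns are tab-separated; lines starting
with # are dropped by the filter):

    # source	target	length
    a	b	2
    b	c	3
    a	c	7
    # vertex	initialDistance
    a	0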
diff --git a/bagel/src/main/scala/bagel/WikipediaPageRank.scala b/bagel/src/main/scala/bagel/WikipediaPageRank.scala
new file mode 100644
index 0000000000..a98fd371e1
--- /dev/null
+++ b/bagel/src/main/scala/bagel/WikipediaPageRank.scala
@@ -0,0 +1,201 @@
+package bagel
+
+import spark._
+import spark.SparkContext._
+
+import scala.collection.mutable.ArrayBuffer
+
+import scala.xml.{XML,NodeSeq}
+
+import java.io.{Externalizable,ObjectInput,ObjectOutput,DataOutputStream,DataInputStream}
+
+import com.esotericsoftware.kryo._
+
+object WikipediaPageRank {
+  def main(args: Array[String]) {
+    if (args.length < 4) {
+      System.err.println("Usage: PageRank <inputFile> <threshold> <numSplits> <host> [<noCombiner>]")
+      System.exit(-1)
+    }
+
+    System.setProperty("spark.serialization", "spark.KryoSerialization")
+    System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName)
+
+    val inputFile = args(0)
+    val threshold = args(1).toDouble
+    val numSplits = args(2).toInt
+    val host = args(3)
+    val noCombiner = args.length > 4 && args(4).nonEmpty
+    val sc = new SparkContext(host, "WikipediaPageRank")
+
+    // Parse the Wikipedia page data into a graph
+    val input = sc.textFile(inputFile)
+
+    println("Counting vertices...")
+    val numVertices = input.count()
+    println("Done counting vertices.")
+
+    println("Parsing input file...")
+    val vertices: RDD[(String, PRVertex)] = input.map(line => {
+      val fields = line.split("\t")
+      val (title, body) = (fields(1), fields(3).replace("\\n", "\n"))
+      val links =
+        if (body == "\\N")
+          NodeSeq.Empty
+        else
+          try {
+            XML.loadString(body) \\ "link" \ "target"
+          } catch {
+            case e: org.xml.sax.SAXParseException =>
+              System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body)
+              NodeSeq.Empty
+          }
+      val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*)
+      val id = new String(title)
+      (id, (new PRVertex(id, 1.0 / numVertices, outEdges, true)))
+    })
+    val graph = vertices.groupByKey(numSplits).mapValues(_.head).cache
+    println("Done parsing input file.")
+    println("Input file had "+graph.count+" vertices.")
+
+    // Do the computation
+    val epsilon = 0.01 / numVertices
+    val result =
+      if (noCombiner) {
+        val messages = sc.parallelize(List[(String, PRMessage)]())
+        Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, graph, messages, numSplits, NoCombiner.messageCombiner, NoCombiner.defaultCombined, NoCombiner.mergeCombined)(NoCombiner.compute(numVertices, epsilon))
+      } else {
+        val messages = sc.parallelize(List[(String, PRMessage)]())
+        Pregel.run[PRVertex, PRMessage, Double](sc, graph, messages, numSplits, Combiner.messageCombiner, Combiner.defaultCombined, Combiner.mergeCombined)(Combiner.compute(numVertices, epsilon))
+      }
+
+    // Print the result
+    System.err.println("Articles with PageRank >= "+threshold+":")
+    val top = result.filter(_.value >= threshold).map(vertex =>
+      "%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString
+    println(top)
+  }
+
+  object Combiner {
+    def messageCombiner(sumSoFar: Double, message: PRMessage): Double =
+      sumSoFar + message.value
+
+    def mergeCombined(a: Double, b: Double) = a + b
+
+    def defaultCombined(): Double = 0.0
+
+    def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Double, superstep: Int): (PRVertex, Iterable[PRMessage]) = {
+      val newValue =
+        if (messageSum != 0)
+          0.15 / numVertices + 0.85 * messageSum
+        else
+          self.value
+
+      val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
+
+      val outbox =
+        if (!terminate)
+          self.outEdges.map(edge =>
+            new PRMessage(edge.targetId, newValue / self.outEdges.size))
+        else
+          ArrayBuffer[PRMessage]()
+
+      (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
+    }
+  }
+
+  object NoCombiner {
+    def messageCombiner(messagesSoFar: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] =
+      messagesSoFar += message
+
+    def mergeCombined(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] =
+      a ++= b
+
+    def defaultCombined(): ArrayBuffer[PRMessage] = ArrayBuffer[PRMessage]()
+
+    def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Seq[PRMessage], superstep: Int): (PRVertex, Iterable[PRMessage]) =
+      Combiner.compute(numVertices, epsilon)(self, messages.map(_.value).sum, superstep)
+  }
+}
+
+@serializable class PRVertex() extends Vertex with Externalizable {
+  var id: String = _
+  var value: Double = _
+  var outEdges: ArrayBuffer[PREdge] = _
+  var active: Boolean = true
+
+  def this(id: String, value: Double, outEdges: ArrayBuffer[PREdge], active: Boolean) {
+    this()
+    this.id = id
+    this.value = value
+    this.outEdges = outEdges
+    this.active = active
+  }
+
+  def writeExternal(out: ObjectOutput) {
+    out.writeUTF(id)
+    out.writeDouble(value)
+    out.writeInt(outEdges.length)
+    for (e <- outEdges)
+      out.writeUTF(e.targetId)
+    out.writeBoolean(active)
+  }
+
+  def readExternal(in: ObjectInput) {
+    id = in.readUTF()
+    value = in.readDouble()
+    val numEdges = in.readInt()
+    outEdges = new ArrayBuffer[PREdge](numEdges)
+    for (i <- 0 until numEdges) {
+      outEdges += new PREdge(in.readUTF())
+    }
+    active = in.readBoolean()
+  }
+}
+
+@serializable class PRMessage() extends Message with Externalizable {
+  var targetId: String = _
+  var value: Double = _
+
+  def this(targetId: String, value: Double) {
+    this()
+    this.targetId = targetId
+    this.value = value
+  }
+
+  def writeExternal(out: ObjectOutput) {
+    out.writeUTF(targetId)
+    out.writeDouble(value)
+  }
+
+  def readExternal(in: ObjectInput) {
+    targetId = in.readUTF()
+    value = in.readDouble()
+  }
+}
+
+@serializable class PREdge() extends Edge with Externalizable {
+  var targetId: String = _
+
+  def this(targetId: String) {
+    this()
+    this.targetId = targetId
+  }
+
+  def writeExternal(out: ObjectOutput) {
+    out.writeUTF(targetId)
+  }
+
+  def readExternal(in: ObjectInput) {
+    targetId = in.readUTF()
+  }
+}
+
+class PRKryoRegistrator extends KryoRegistrator {
+  def registerClasses(kryo: Kryo) {
+    kryo.register(classOf[PRVertex])
+    kryo.register(classOf[PRMessage])
+    kryo.register(classOf[PREdge])
+  }
+}
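The two Pregel.run calls above differ only in their combiner triple, whose contract
mirrors combineByKey: messageCombiner folds one message into a partial result,
defaultCombined produces the identity value, and mergeCombined merges partials
built on different partitions. A local sketch of what the framework computes for a
single vertex's inbox (no Spark needed; the two sample messages are made up):

    package bagel

    object CombinerDemo {
      def main(args: Array[String]) {
        import WikipediaPageRank.Combiner._

        // Hypothetical inbox for one vertex during a superstep.
        val inbox = List(new PRMessage("SomePage", 0.25), new PRMessage("SomePage", 0.5))

        // Within one partition: fold messages into a partial sum.
        val partial = inbox.foldLeft(defaultCombined())(messageCombiner)

        // Across partitions: merge partial sums (trivial with one partition).
        val total = mergeCombined(partial, defaultCombined())

        println(total)  // 0.75, the summed PageRank contributions
      }
    }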
diff --git a/project/build/SparkProject.scala b/project/build/SparkProject.scala
index 484daf5c50..a6ee25bc3d 100644
--- a/project/build/SparkProject.scala
+++ b/project/build/SparkProject.scala
@@ -14,6 +14,8 @@ extends ParentProject(info) with IdeaProject
   lazy val examples = project("examples", "Spark Examples", new ExamplesProject(_), core)
 
+  lazy val bagel = project("bagel", "Bagel", core)
+
   class CoreProject(info: ProjectInfo) extends DefaultProject(info) with Eclipsify with IdeaProject with DepJar with XmlTestReport
   {}
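This registration makes bagel a subproject of the parent build alongside core and
examples. Assuming the usual console workflow for sbt builds of this vintage (an
assumption based on the build style, not stated in the patch), the new module
would be compiled with:

    $ sbt
    > project bagel
    > compile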