From f40a0898a7f627f0d66f8393f724b518c50fba09 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Mon, 9 May 2011 15:23:21 -0700 Subject: Rename bagel to spark.bagel and Pregel to Bagel --- bagel/src/main/scala/bagel/Pregel.scala | 159 --------------------- .../main/scala/bagel/examples/ShortestPath.scala | 96 ------------- .../scala/bagel/examples/WikipediaPageRank.scala | 158 -------------------- bagel/src/main/scala/spark/bagel/Bagel.scala | 159 +++++++++++++++++++++ .../scala/spark/bagel/examples/ShortestPath.scala | 96 +++++++++++++ .../spark/bagel/examples/WikipediaPageRank.scala | 158 ++++++++++++++++++++ bagel/src/test/scala/bagel/BagelSuite.scala | 8 +- 7 files changed, 417 insertions(+), 417 deletions(-) delete mode 100644 bagel/src/main/scala/bagel/Pregel.scala delete mode 100644 bagel/src/main/scala/bagel/examples/ShortestPath.scala delete mode 100644 bagel/src/main/scala/bagel/examples/WikipediaPageRank.scala create mode 100644 bagel/src/main/scala/spark/bagel/Bagel.scala create mode 100644 bagel/src/main/scala/spark/bagel/examples/ShortestPath.scala create mode 100644 bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala (limited to 'bagel/src') diff --git a/bagel/src/main/scala/bagel/Pregel.scala b/bagel/src/main/scala/bagel/Pregel.scala deleted file mode 100644 index 2ec94ebb0d..0000000000 --- a/bagel/src/main/scala/bagel/Pregel.scala +++ /dev/null @@ -1,159 +0,0 @@ -package bagel - -import spark._ -import spark.SparkContext._ - -import scala.collection.mutable.ArrayBuffer - -object Pregel extends Logging { - def run[V <: Vertex : Manifest, M <: Message : Manifest, C : Manifest, A : Manifest]( - sc: SparkContext, - verts: RDD[(String, V)], - msgs: RDD[(String, M)] - )( - combiner: Combiner[M, C] = new DefaultCombiner[M], - aggregator: Aggregator[V, A] = new NullAggregator[V], - superstep: Int = 0, - numSplits: Int = sc.numCores - )( - compute: (V, Option[C], A, Int) => (V, Iterable[M]) - ): RDD[V] = { - - logInfo("Starting superstep "+superstep+".") - val startTime = System.currentTimeMillis - - val aggregated = agg(verts, aggregator) - val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits) - val grouped = verts.groupWith(combinedMsgs) - val (processed, numMsgs, numActiveVerts) = comp[V, M, C](sc, grouped, compute(_, _, aggregated, superstep)) - - val timeTaken = System.currentTimeMillis - startTime - logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) - - // Check stopping condition and iterate - val noActivity = numMsgs == 0 && numActiveVerts == 0 - if (noActivity) { - processed.map { case (id, (vert, msgs)) => vert } - } else { - val newVerts = processed.mapValues { case (vert, msgs) => vert } - val newMsgs = processed.flatMap { - case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) - } - run(sc, newVerts, newMsgs)(combiner, aggregator, superstep + 1, numSplits)(compute) - } - } - - /** - * Aggregates the given vertices using the given aggregator, or does - * nothing if it is a NullAggregator. - */ - def agg[V <: Vertex, A : Manifest](verts: RDD[(String, V)], aggregator: Aggregator[V, A]): A = aggregator match { - case _: NullAggregator[_] => - None - case _ => - verts.map { - case (id, vert) => aggregator.createAggregator(vert) - }.reduce(aggregator.mergeAggregators(_, _)) - } - - /** - * Processes the given vertex-message RDD using the compute - * function. Returns the processed RDD, the number of messages - * created, and the number of active vertices. 
- */ - def comp[V <: Vertex, M <: Message, C](sc: SparkContext, grouped: RDD[(String, (Seq[V], Seq[C]))], compute: (V, Option[C]) => (V, Iterable[M])): (RDD[(String, (V, Iterable[M]))], Int, Int) = { - var numMsgs = sc.accumulator(0) - var numActiveVerts = sc.accumulator(0) - val processed = grouped.flatMapValues { - case (Seq(), _) => None - case (Seq(v), c) => - val (newVert, newMsgs) = - compute(v, c match { - case Seq(comb) => Some(comb) - case Seq() => None - }) - - numMsgs += newMsgs.size - if (newVert.active) - numActiveVerts += 1 - - Some((newVert, newMsgs)) - }.cache - - // Force evaluation of processed RDD for accurate performance measurements - processed.foreach(x => {}) - - (processed, numMsgs.value, numActiveVerts.value) - } - - /** - * Converts a compute function that doesn't take an aggregator to - * one that does, so it can be passed to Pregel.run. - */ - implicit def addAggregatorArg[ - V <: Vertex : Manifest, M <: Message : Manifest, C - ]( - compute: (V, Option[C], Int) => (V, Iterable[M]) - ): (V, Option[C], Option[Nothing], Int) => (V, Iterable[M]) = { - (vert: V, messages: Option[C], aggregator: Option[Nothing], superstep: Int) => compute(vert, messages, superstep) - } -} - -// TODO: Simplify Combiner interface and make it more OO. -trait Combiner[M, C] { - def createCombiner(msg: M): C - def mergeMsg(combiner: C, msg: M): C - def mergeCombiners(a: C, b: C): C -} - -trait Aggregator[V, A] { - def createAggregator(vert: V): A - def mergeAggregators(a: A, b: A): A -} - -@serializable -class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] { - def createCombiner(msg: M): ArrayBuffer[M] = - ArrayBuffer(msg) - def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] = - combiner += msg - def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] = - a ++= b -} - -@serializable -class NullAggregator[V] extends Aggregator[V, Option[Nothing]] { - def createAggregator(vert: V): Option[Nothing] = None - def mergeAggregators(a: Option[Nothing], b: Option[Nothing]): Option[Nothing] = None -} - -/** - * Represents a Pregel vertex. - * - * Subclasses may store state along with each vertex and must be - * annotated with @serializable. - */ -trait Vertex { - def id: String - def active: Boolean -} - -/** - * Represents a Pregel message to a target vertex. - * - * Subclasses may contain a payload to deliver to the target vertex - * and must be annotated with @serializable. - */ -trait Message { - def targetId: String -} - -/** - * Represents a directed edge between two vertices. - * - * Subclasses may store state along each edge and must be annotated - * with @serializable. 
- */ -trait Edge { - def targetId: String -} diff --git a/bagel/src/main/scala/bagel/examples/ShortestPath.scala b/bagel/src/main/scala/bagel/examples/ShortestPath.scala deleted file mode 100644 index 2e6100c070..0000000000 --- a/bagel/src/main/scala/bagel/examples/ShortestPath.scala +++ /dev/null @@ -1,96 +0,0 @@ -package bagel.examples - -import spark._ -import spark.SparkContext._ - -import scala.math.min - -import bagel._ -import bagel.Pregel._ - -object ShortestPath { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: ShortestPath " + - " ") - System.exit(-1) - } - - val graphFile = args(0) - val startVertex = args(1) - val numSplits = args(2).toInt - val host = args(3) - val sc = new SparkContext(host, "ShortestPath") - - // Parse the graph data from a file into two RDDs, vertices and messages - val lines = - (sc.textFile(graphFile) - .filter(!_.matches("^\\s*#.*")) - .map(line => line.split("\t"))) - - val vertices: RDD[(String, SPVertex)] = - (lines.groupBy(line => line(0)) - .map { - case (vertexId, lines) => { - val outEdges = lines.collect { - case Array(_, targetId, edgeValue) => - new SPEdge(targetId, edgeValue.toInt) - } - - (vertexId, new SPVertex(vertexId, Int.MaxValue, outEdges, true)) - } - }) - - val messages: RDD[(String, SPMessage)] = - (lines.filter(_.length == 2) - .map { - case Array(vertexId, messageValue) => - (vertexId, new SPMessage(vertexId, messageValue.toInt)) - }) - - System.err.println("Read "+vertices.count()+" vertices and "+ - messages.count()+" messages.") - - // Do the computation - val compute = addAggregatorArg { - (self: SPVertex, messageMinValue: Option[Int], superstep: Int) => - val newValue = messageMinValue match { - case Some(minVal) => min(self.value, minVal) - case None => self.value - } - - val outbox = - if (newValue != self.value) - self.outEdges.map(edge => - new SPMessage(edge.targetId, newValue + edge.value)) - else - List() - - (new SPVertex(self.id, newValue, self.outEdges, false), outbox) - } - val result = Pregel.run(sc, vertices, messages)(combiner = MinCombiner, numSplits = numSplits)(compute) - - // Print the result - System.err.println("Shortest path from "+startVertex+" to all vertices:") - val shortest = result.map(vertex => - "%s\t%s\n".format(vertex.id, vertex.value match { - case x if x == Int.MaxValue => "inf" - case x => x - })).collect.mkString - println(shortest) - } -} - -@serializable -object MinCombiner extends Combiner[SPMessage, Int] { - def createCombiner(msg: SPMessage): Int = - msg.value - def mergeMsg(combiner: Int, msg: SPMessage): Int = - min(combiner, msg.value) - def mergeCombiners(a: Int, b: Int): Int = - min(a, b) -} - -@serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex -@serializable class SPEdge(val targetId: String, val value: Int) extends Edge -@serializable class SPMessage(val targetId: String, val value: Int) extends Message diff --git a/bagel/src/main/scala/bagel/examples/WikipediaPageRank.scala b/bagel/src/main/scala/bagel/examples/WikipediaPageRank.scala deleted file mode 100644 index a5e0a9ffb6..0000000000 --- a/bagel/src/main/scala/bagel/examples/WikipediaPageRank.scala +++ /dev/null @@ -1,158 +0,0 @@ -package bagel.examples - -import spark._ -import spark.SparkContext._ - -import bagel._ -import bagel.Pregel._ - -import scala.collection.mutable.ArrayBuffer -import scala.xml.{XML,NodeSeq} - -import java.io.{Externalizable,ObjectInput,ObjectOutput,DataOutputStream,DataInputStream} - 
-import com.esotericsoftware.kryo._ - -object WikipediaPageRank { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: WikipediaPageRank []") - System.exit(-1) - } - - System.setProperty("spark.serialization", "spark.KryoSerialization") - System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) - - val inputFile = args(0) - val threshold = args(1).toDouble - val numSplits = args(2).toInt - val host = args(3) - val noCombiner = args.length > 4 && args(4).nonEmpty - val sc = new SparkContext(host, "WikipediaPageRank") - - // Parse the Wikipedia page data into a graph - val input = sc.textFile(inputFile) - - println("Counting vertices...") - val numVertices = input.count() - println("Done counting vertices.") - - println("Parsing input file...") - val vertices: RDD[(String, PRVertex)] = input.map(line => { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val links = - if (body == "\\N") - NodeSeq.Empty - else - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) - NodeSeq.Empty - } - val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*) - val id = new String(title) - (id, new PRVertex(id, 1.0 / numVertices, outEdges, true)) - }).cache - println("Done parsing input file.") - - // Do the computation - val epsilon = 0.01 / numVertices - val messages = sc.parallelize(List[(String, PRMessage)]()) - val result = - if (noCombiner) { - Pregel.run(sc, vertices, messages)(numSplits = numSplits)(PRNoCombiner.compute(numVertices, epsilon)) - } else { - Pregel.run(sc, vertices, messages)(combiner = PRCombiner, numSplits = numSplits)(PRCombiner.compute(numVertices, epsilon)) - } - - // Print the result - System.err.println("Articles with PageRank >= "+threshold+":") - val top = result.filter(_.value >= threshold).map(vertex => - "%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString - println(top) - } -} - -@serializable -object PRCombiner extends Combiner[PRMessage, Double] { - def createCombiner(msg: PRMessage): Double = - msg.value - def mergeMsg(combiner: Double, msg: PRMessage): Double = - combiner + msg.value - def mergeCombiners(a: Double, b: Double): Double = - a + b - - def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = { - val newValue = messageSum match { - case Some(msgSum) if msgSum != 0 => - 0.15 / numVertices + 0.85 * msgSum - case _ => self.value - } - - val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30 - - val outbox = - if (!terminate) - self.outEdges.map(edge => - new PRMessage(edge.targetId, newValue / self.outEdges.size)) - else - ArrayBuffer[PRMessage]() - - (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox) - } -} - -@serializable -object PRNoCombiner extends DefaultCombiner[PRMessage] { - def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) = - PRCombiner.compute(numVertices, epsilon)(self, messages match { - case Some(msgs) => Some(msgs.map(_.value).sum) - case None => None - }, superstep) -} - -@serializable class PRVertex() extends Vertex { - var id: String = _ - var value: Double = _ - var outEdges: ArrayBuffer[PREdge] = _ - var active: Boolean = true - 
- def this(id: String, value: Double, outEdges: ArrayBuffer[PREdge], active: Boolean) { - this() - this.id = id - this.value = value - this.outEdges = outEdges - this.active = active - } -} - -@serializable class PRMessage() extends Message { - var targetId: String = _ - var value: Double = _ - - def this(targetId: String, value: Double) { - this() - this.targetId = targetId - this.value = value - } -} - -@serializable class PREdge() extends Edge { - var targetId: String = _ - - def this(targetId: String) { - this() - this.targetId = targetId - } -} - -class PRKryoRegistrator extends KryoRegistrator { - def registerClasses(kryo: Kryo) { - kryo.register(classOf[PRVertex]) - kryo.register(classOf[PRMessage]) - kryo.register(classOf[PREdge]) - } -} diff --git a/bagel/src/main/scala/spark/bagel/Bagel.scala b/bagel/src/main/scala/spark/bagel/Bagel.scala new file mode 100644 index 0000000000..08ff1d8a01 --- /dev/null +++ b/bagel/src/main/scala/spark/bagel/Bagel.scala @@ -0,0 +1,159 @@ +package spark.bagel + +import spark._ +import spark.SparkContext._ + +import scala.collection.mutable.ArrayBuffer + +object Bagel extends Logging { + def run[V <: Vertex : Manifest, M <: Message : Manifest, C : Manifest, A : Manifest]( + sc: SparkContext, + verts: RDD[(String, V)], + msgs: RDD[(String, M)] + )( + combiner: Combiner[M, C] = new DefaultCombiner[M], + aggregator: Aggregator[V, A] = new NullAggregator[V], + superstep: Int = 0, + numSplits: Int = sc.numCores + )( + compute: (V, Option[C], A, Int) => (V, Iterable[M]) + ): RDD[V] = { + + logInfo("Starting superstep "+superstep+".") + val startTime = System.currentTimeMillis + + val aggregated = agg(verts, aggregator) + val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits) + val grouped = verts.groupWith(combinedMsgs) + val (processed, numMsgs, numActiveVerts) = comp[V, M, C](sc, grouped, compute(_, _, aggregated, superstep)) + + val timeTaken = System.currentTimeMillis - startTime + logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) + + // Check stopping condition and iterate + val noActivity = numMsgs == 0 && numActiveVerts == 0 + if (noActivity) { + processed.map { case (id, (vert, msgs)) => vert } + } else { + val newVerts = processed.mapValues { case (vert, msgs) => vert } + val newMsgs = processed.flatMap { + case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) + } + run(sc, newVerts, newMsgs)(combiner, aggregator, superstep + 1, numSplits)(compute) + } + } + + /** + * Aggregates the given vertices using the given aggregator, or does + * nothing if it is a NullAggregator. + */ + def agg[V <: Vertex, A : Manifest](verts: RDD[(String, V)], aggregator: Aggregator[V, A]): A = aggregator match { + case _: NullAggregator[_] => + None + case _ => + verts.map { + case (id, vert) => aggregator.createAggregator(vert) + }.reduce(aggregator.mergeAggregators(_, _)) + } + + /** + * Processes the given vertex-message RDD using the compute + * function. Returns the processed RDD, the number of messages + * created, and the number of active vertices. 
+ */ + def comp[V <: Vertex, M <: Message, C](sc: SparkContext, grouped: RDD[(String, (Seq[V], Seq[C]))], compute: (V, Option[C]) => (V, Iterable[M])): (RDD[(String, (V, Iterable[M]))], Int, Int) = { + var numMsgs = sc.accumulator(0) + var numActiveVerts = sc.accumulator(0) + val processed = grouped.flatMapValues { + case (Seq(), _) => None + case (Seq(v), c) => + val (newVert, newMsgs) = + compute(v, c match { + case Seq(comb) => Some(comb) + case Seq() => None + }) + + numMsgs += newMsgs.size + if (newVert.active) + numActiveVerts += 1 + + Some((newVert, newMsgs)) + }.cache + + // Force evaluation of processed RDD for accurate performance measurements + processed.foreach(x => {}) + + (processed, numMsgs.value, numActiveVerts.value) + } + + /** + * Converts a compute function that doesn't take an aggregator to + * one that does, so it can be passed to Bagel.run. + */ + implicit def addAggregatorArg[ + V <: Vertex : Manifest, M <: Message : Manifest, C + ]( + compute: (V, Option[C], Int) => (V, Iterable[M]) + ): (V, Option[C], Option[Nothing], Int) => (V, Iterable[M]) = { + (vert: V, messages: Option[C], aggregator: Option[Nothing], superstep: Int) => compute(vert, messages, superstep) + } +} + +// TODO: Simplify Combiner interface and make it more OO. +trait Combiner[M, C] { + def createCombiner(msg: M): C + def mergeMsg(combiner: C, msg: M): C + def mergeCombiners(a: C, b: C): C +} + +trait Aggregator[V, A] { + def createAggregator(vert: V): A + def mergeAggregators(a: A, b: A): A +} + +@serializable +class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] { + def createCombiner(msg: M): ArrayBuffer[M] = + ArrayBuffer(msg) + def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] = + combiner += msg + def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] = + a ++= b +} + +@serializable +class NullAggregator[V] extends Aggregator[V, Option[Nothing]] { + def createAggregator(vert: V): Option[Nothing] = None + def mergeAggregators(a: Option[Nothing], b: Option[Nothing]): Option[Nothing] = None +} + +/** + * Represents a Bagel vertex. + * + * Subclasses may store state along with each vertex and must be + * annotated with @serializable. + */ +trait Vertex { + def id: String + def active: Boolean +} + +/** + * Represents a Bagel message to a target vertex. + * + * Subclasses may contain a payload to deliver to the target vertex + * and must be annotated with @serializable. + */ +trait Message { + def targetId: String +} + +/** + * Represents a directed edge between two vertices. + * + * Subclasses may store state along each edge and must be annotated + * with @serializable. 
+ */ +trait Edge { + def targetId: String +} diff --git a/bagel/src/main/scala/spark/bagel/examples/ShortestPath.scala b/bagel/src/main/scala/spark/bagel/examples/ShortestPath.scala new file mode 100644 index 0000000000..a7fd386310 --- /dev/null +++ b/bagel/src/main/scala/spark/bagel/examples/ShortestPath.scala @@ -0,0 +1,96 @@ +package spark.bagel.examples + +import spark._ +import spark.SparkContext._ + +import scala.math.min + +import spark.bagel._ +import spark.bagel.Bagel._ + +object ShortestPath { + def main(args: Array[String]) { + if (args.length < 4) { + System.err.println("Usage: ShortestPath " + + " ") + System.exit(-1) + } + + val graphFile = args(0) + val startVertex = args(1) + val numSplits = args(2).toInt + val host = args(3) + val sc = new SparkContext(host, "ShortestPath") + + // Parse the graph data from a file into two RDDs, vertices and messages + val lines = + (sc.textFile(graphFile) + .filter(!_.matches("^\\s*#.*")) + .map(line => line.split("\t"))) + + val vertices: RDD[(String, SPVertex)] = + (lines.groupBy(line => line(0)) + .map { + case (vertexId, lines) => { + val outEdges = lines.collect { + case Array(_, targetId, edgeValue) => + new SPEdge(targetId, edgeValue.toInt) + } + + (vertexId, new SPVertex(vertexId, Int.MaxValue, outEdges, true)) + } + }) + + val messages: RDD[(String, SPMessage)] = + (lines.filter(_.length == 2) + .map { + case Array(vertexId, messageValue) => + (vertexId, new SPMessage(vertexId, messageValue.toInt)) + }) + + System.err.println("Read "+vertices.count()+" vertices and "+ + messages.count()+" messages.") + + // Do the computation + val compute = addAggregatorArg { + (self: SPVertex, messageMinValue: Option[Int], superstep: Int) => + val newValue = messageMinValue match { + case Some(minVal) => min(self.value, minVal) + case None => self.value + } + + val outbox = + if (newValue != self.value) + self.outEdges.map(edge => + new SPMessage(edge.targetId, newValue + edge.value)) + else + List() + + (new SPVertex(self.id, newValue, self.outEdges, false), outbox) + } + val result = Bagel.run(sc, vertices, messages)(combiner = MinCombiner, numSplits = numSplits)(compute) + + // Print the result + System.err.println("Shortest path from "+startVertex+" to all vertices:") + val shortest = result.map(vertex => + "%s\t%s\n".format(vertex.id, vertex.value match { + case x if x == Int.MaxValue => "inf" + case x => x + })).collect.mkString + println(shortest) + } +} + +@serializable +object MinCombiner extends Combiner[SPMessage, Int] { + def createCombiner(msg: SPMessage): Int = + msg.value + def mergeMsg(combiner: Int, msg: SPMessage): Int = + min(combiner, msg.value) + def mergeCombiners(a: Int, b: Int): Int = + min(a, b) +} + +@serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex +@serializable class SPEdge(val targetId: String, val value: Int) extends Edge +@serializable class SPMessage(val targetId: String, val value: Int) extends Message diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala new file mode 100644 index 0000000000..1bce5bebad --- /dev/null +++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala @@ -0,0 +1,158 @@ +package spark.bagel.examples + +import spark._ +import spark.SparkContext._ + +import spark.bagel._ +import spark.bagel.Bagel._ + +import scala.collection.mutable.ArrayBuffer +import scala.xml.{XML,NodeSeq} + +import 
java.io.{Externalizable,ObjectInput,ObjectOutput,DataOutputStream,DataInputStream} + +import com.esotericsoftware.kryo._ + +object WikipediaPageRank { + def main(args: Array[String]) { + if (args.length < 4) { + System.err.println("Usage: WikipediaPageRank []") + System.exit(-1) + } + + System.setProperty("spark.serialization", "spark.KryoSerialization") + System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) + + val inputFile = args(0) + val threshold = args(1).toDouble + val numSplits = args(2).toInt + val host = args(3) + val noCombiner = args.length > 4 && args(4).nonEmpty + val sc = new SparkContext(host, "WikipediaPageRank") + + // Parse the Wikipedia page data into a graph + val input = sc.textFile(inputFile) + + println("Counting vertices...") + val numVertices = input.count() + println("Done counting vertices.") + + println("Parsing input file...") + val vertices: RDD[(String, PRVertex)] = input.map(line => { + val fields = line.split("\t") + val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) + val links = + if (body == "\\N") + NodeSeq.Empty + else + try { + XML.loadString(body) \\ "link" \ "target" + } catch { + case e: org.xml.sax.SAXParseException => + System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) + NodeSeq.Empty + } + val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*) + val id = new String(title) + (id, new PRVertex(id, 1.0 / numVertices, outEdges, true)) + }).cache + println("Done parsing input file.") + + // Do the computation + val epsilon = 0.01 / numVertices + val messages = sc.parallelize(List[(String, PRMessage)]()) + val result = + if (noCombiner) { + Bagel.run(sc, vertices, messages)(numSplits = numSplits)(PRNoCombiner.compute(numVertices, epsilon)) + } else { + Bagel.run(sc, vertices, messages)(combiner = PRCombiner, numSplits = numSplits)(PRCombiner.compute(numVertices, epsilon)) + } + + // Print the result + System.err.println("Articles with PageRank >= "+threshold+":") + val top = result.filter(_.value >= threshold).map(vertex => + "%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString + println(top) + } +} + +@serializable +object PRCombiner extends Combiner[PRMessage, Double] { + def createCombiner(msg: PRMessage): Double = + msg.value + def mergeMsg(combiner: Double, msg: PRMessage): Double = + combiner + msg.value + def mergeCombiners(a: Double, b: Double): Double = + a + b + + def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = { + val newValue = messageSum match { + case Some(msgSum) if msgSum != 0 => + 0.15 / numVertices + 0.85 * msgSum + case _ => self.value + } + + val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30 + + val outbox = + if (!terminate) + self.outEdges.map(edge => + new PRMessage(edge.targetId, newValue / self.outEdges.size)) + else + ArrayBuffer[PRMessage]() + + (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox) + } +} + +@serializable +object PRNoCombiner extends DefaultCombiner[PRMessage] { + def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) = + PRCombiner.compute(numVertices, epsilon)(self, messages match { + case Some(msgs) => Some(msgs.map(_.value).sum) + case None => None + }, superstep) +} + +@serializable class PRVertex() extends Vertex { + var id: String = _ + var 
value: Double = _ + var outEdges: ArrayBuffer[PREdge] = _ + var active: Boolean = true + + def this(id: String, value: Double, outEdges: ArrayBuffer[PREdge], active: Boolean) { + this() + this.id = id + this.value = value + this.outEdges = outEdges + this.active = active + } +} + +@serializable class PRMessage() extends Message { + var targetId: String = _ + var value: Double = _ + + def this(targetId: String, value: Double) { + this() + this.targetId = targetId + this.value = value + } +} + +@serializable class PREdge() extends Edge { + var targetId: String = _ + + def this(targetId: String) { + this() + this.targetId = targetId + } +} + +class PRKryoRegistrator extends KryoRegistrator { + def registerClasses(kryo: Kryo) { + kryo.register(classOf[PRVertex]) + kryo.register(classOf[PRMessage]) + kryo.register(classOf[PREdge]) + } +} diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala index 53a93a6b80..1b47fc9ed5 100644 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/bagel/BagelSuite.scala @@ -1,4 +1,4 @@ -package bagel +package spark.bagel import org.scalatest.{FunSuite, Assertions} import org.scalatest.prop.Checkers @@ -10,7 +10,7 @@ import scala.collection.mutable.ArrayBuffer import spark._ -import bagel.Pregel._ +import spark.bagel.Bagel._ @serializable class TestVertex(val id: String, val active: Boolean, val age: Int) extends Vertex @serializable class TestMessage(val targetId: String) extends Message @@ -22,7 +22,7 @@ class BagelSuite extends FunSuite with Assertions { val msgs = sc.parallelize(Array[(String, TestMessage)]()) val numSupersteps = 5 val result = - Pregel.run(sc, verts, msgs)()(addAggregatorArg { + Bagel.run(sc, verts, msgs)()(addAggregatorArg { (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) => (new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) }) @@ -36,7 +36,7 @@ class BagelSuite extends FunSuite with Assertions { val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) val numSupersteps = 5 val result = - Pregel.run(sc, verts, msgs)()(addAggregatorArg { + Bagel.run(sc, verts, msgs)()(addAggregatorArg { (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) => val msgsOut = msgs match { -- cgit v1.2.3
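
The BagelSuite hunks above capture everything a caller of this library has to change after the rename: package bagel becomes spark.bagel, import bagel.Pregel._ becomes import spark.bagel.Bagel._, and Pregel.run becomes Bagel.run, while every argument list stays the same. The sketch below exercises the renamed API end to end in the style of that test suite; the class names, the "local" master string, and the superstep cutoff are illustrative assumptions, not part of this commit.

import spark._
import spark.SparkContext._
import spark.bagel._
import spark.bagel.Bagel._

import scala.collection.mutable.ArrayBuffer

// Illustrative vertex/message types, mirroring TestVertex/TestMessage in BagelSuite.
@serializable class CountVertex(val id: String, val active: Boolean, val age: Int) extends Vertex
@serializable class CountMessage(val targetId: String) extends Message

object BagelRenameDemo {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "BagelRenameDemo")  // assumed local master
    val verts = sc.parallelize(List("a", "b", "c").map(id => (id, new CountVertex(id, true, 0))))
    val msgs = sc.parallelize(List[(String, CountMessage)]())
    val numSupersteps = 5

    // Bagel.run (formerly Pregel.run): the empty middle argument list picks up the
    // DefaultCombiner/NullAggregator defaults, and addAggregatorArg adapts a compute
    // function that ignores the aggregator slot.
    val result = Bagel.run(sc, verts, msgs)()(addAggregatorArg {
      (self: CountVertex, msgs: Option[ArrayBuffer[CountMessage]], superstep: Int) =>
        // Stay active until the cutoff so the run lasts exactly numSupersteps supersteps.
        (new CountVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[CountMessage]())
    })

    result.foreach(v => println(v.id + " was computed in " + v.age + " supersteps"))
  }
}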
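
The ShortestPath and WikipediaPageRank examples in this commit use the other calling convention: a custom Combiner and a split count passed by name in the middle argument list, so each vertex receives one pre-reduced value per superstep instead of a buffer of raw messages. Below is a condensed, self-contained sketch of that pattern; the two-vertex graph, the class names, and the numSplits = 2 setting are made up for illustration (edges are collapsed to plain target-id strings), and only the API shapes come from the files added above.

import spark._
import spark.SparkContext._
import spark.bagel._
import spark.bagel.Bagel._

import scala.math.min

// Illustrative vertex/message classes in the style of SPVertex/SPMessage above.
@serializable class HopVertex(val id: String, val value: Int, val targets: Seq[String], val active: Boolean) extends Vertex
@serializable class HopMessage(val targetId: String, val value: Int) extends Message

// Keeps only the smallest incoming value per vertex, like MinCombiner above.
@serializable object MinHopCombiner extends Combiner[HopMessage, Int] {
  def createCombiner(msg: HopMessage): Int = msg.value
  def mergeMsg(combiner: Int, msg: HopMessage): Int = min(combiner, msg.value)
  def mergeCombiners(a: Int, b: Int): Int = min(a, b)
}

object MinHopDemo {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "MinHopDemo")  // assumed local master
    val verts = sc.parallelize(List(
      "a" -> new HopVertex("a", 0, List("b"), true),
      "b" -> new HopVertex("b", Int.MaxValue, List(), true)))
    val msgs = sc.parallelize(List("b" -> new HopMessage("b", 1)))

    val compute = addAggregatorArg {
      (self: HopVertex, minIn: Option[Int], superstep: Int) =>
        // The combiner has already reduced this vertex's inbox to a single minimum.
        val newValue = minIn match {
          case Some(m) => min(self.value, m)
          case None => self.value
        }
        val outbox =
          if (newValue != self.value)
            self.targets.map(t => new HopMessage(t, newValue + 1))
          else
            List()
        (new HopVertex(self.id, newValue, self.targets, false), outbox)
    }

    // Custom combiner and split count go by name in the middle argument list,
    // exactly as ShortestPath and WikipediaPageRank call it above.
    val result = Bagel.run(sc, verts, msgs)(combiner = MinHopCombiner, numSplits = 2)(compute)
    result.foreach(v => println(v.id + " -> " + v.value))
  }
}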