From cbdc01eecd235d03bf04f8e63c5dfac7cd622134 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Sun, 9 Oct 2011 16:00:59 -0700 Subject: Update WikipediaPageRank to reflect Bagel API changes --- .../scala/spark/bagel/examples/PageRankUtils.scala | 106 ++++++++++++++++++ .../spark/bagel/examples/WikipediaPageRank.scala | 123 ++++----------------- 2 files changed, 129 insertions(+), 100 deletions(-) create mode 100644 bagel/src/main/scala/spark/bagel/examples/PageRankUtils.scala (limited to 'bagel/src') diff --git a/bagel/src/main/scala/spark/bagel/examples/PageRankUtils.scala b/bagel/src/main/scala/spark/bagel/examples/PageRankUtils.scala new file mode 100644 index 0000000000..b97d786ed4 --- /dev/null +++ b/bagel/src/main/scala/spark/bagel/examples/PageRankUtils.scala @@ -0,0 +1,106 @@ +package spark.bagel.examples + +import spark._ +import spark.SparkContext._ + +import spark.bagel._ +import spark.bagel.Bagel._ + +import scala.collection.mutable.ArrayBuffer + +import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} + +import com.esotericsoftware.kryo._ + +class PageRankUtils extends Serializable { + def computeWithCombiner(numVertices: Long, epsilon: Double)( + self: PRVertex, messageSum: Option[Double], superstep: Int + ): (PRVertex, Array[PRMessage]) = { + val newValue = messageSum match { + case Some(msgSum) if msgSum != 0 => + 0.15 / numVertices + 0.85 * msgSum + case _ => self.value + } + + val terminate = superstep >= 10 + + val outbox: Array[PRMessage] = + if (!terminate) + self.outEdges.map(targetId => + new PRMessage(targetId, newValue / self.outEdges.size)) + else + Array[PRMessage]() + + (new PRVertex(newValue, self.outEdges, !terminate), outbox) + } + + def computeNoCombiner(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int): (PRVertex, Array[PRMessage]) = + computeWithCombiner(numVertices, epsilon)(self, messages match { + case Some(msgs) => Some(msgs.map(_.value).sum) + case None => None + }, superstep) +} + +class PRCombiner extends Combiner[PRMessage, Double] with Serializable { + def createCombiner(msg: PRMessage): Double = + msg.value + def mergeMsg(combiner: Double, msg: PRMessage): Double = + combiner + msg.value + def mergeCombiners(a: Double, b: Double): Double = + a + b +} + +class PRVertex() extends Vertex with Serializable { + var value: Double = _ + var outEdges: Array[String] = _ + var active: Boolean = _ + + def this(value: Double, outEdges: Array[String], active: Boolean = true) { + this() + this.value = value + this.outEdges = outEdges + this.active = active + } + + override def toString(): String = { + "PRVertex(value=%f, outEdges.length=%d, active=%s)".format(value, outEdges.length, active.toString) + } +} + +class PRMessage() extends Message[String] with Serializable { + var targetId: String = _ + var value: Double = _ + + def this(targetId: String, value: Double) { + this() + this.targetId = targetId + this.value = value + } +} + +class PRKryoRegistrator extends KryoRegistrator { + def registerClasses(kryo: Kryo) { + kryo.register(classOf[PRVertex]) + kryo.register(classOf[PRMessage]) + } +} + +class CustomPartitioner(partitions: Int) extends Partitioner { + def numPartitions = partitions + + def getPartition(key: Any): Int = { + val hash = key match { + case k: Long => (k & 0x00000000FFFFFFFFL).toInt + case _ => key.hashCode + } + + val mod = key.hashCode % partitions + if (mod < 0) mod + partitions else mod + } + + override def equals(other: Any): Boolean = other match { + case c: CustomPartitioner => + c.numPartitions == numPartitions + case _ => false + } +} diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala index 9a0dbbe9d7..f37ee01fd2 100644 --- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala +++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala @@ -6,28 +6,23 @@ import spark.SparkContext._ import spark.bagel._ import spark.bagel.Bagel._ -import scala.collection.mutable.ArrayBuffer import scala.xml.{XML,NodeSeq} -import java.io.{Externalizable,ObjectInput,ObjectOutput,DataOutputStream,DataInputStream} - -import com.esotericsoftware.kryo._ - object WikipediaPageRank { def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: WikipediaPageRank []") + if (args.length < 5) { + System.err.println("Usage: WikipediaPageRank ") System.exit(-1) } - System.setProperty("spark.serialization", "spark.KryoSerialization") + System.setProperty("spark.serializer", "spark.KryoSerializer") System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) val inputFile = args(0) val threshold = args(1).toDouble val numSplits = args(2).toInt val host = args(3) - val noCombiner = args.length > 4 && args(4).nonEmpty + val usePartitioner = args(4).toBoolean val sc = new SparkContext(host, "WikipediaPageRank") // Parse the Wikipedia page data into a graph @@ -38,7 +33,7 @@ object WikipediaPageRank { println("Done counting vertices.") println("Parsing input file...") - val vertices: RDD[(String, PRVertex)] = input.map(line => { + var vertices = input.map(line => { val fields = line.split("\t") val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) val links = @@ -52,105 +47,33 @@ object WikipediaPageRank { System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) NodeSeq.Empty } - val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*) + val outEdges = links.map(link => new String(link.text)).toArray val id = new String(title) - (id, new PRVertex(id, 1.0 / numVertices, outEdges, true)) - }).cache + (id, new PRVertex(1.0 / numVertices, outEdges)) + }) + if (usePartitioner) + vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache + else + vertices = vertices.cache println("Done parsing input file.") // Do the computation val epsilon = 0.01 / numVertices - val messages = sc.parallelize(List[(String, PRMessage)]()) + val messages = sc.parallelize(Array[(String, PRMessage)]()) + val utils = new PageRankUtils val result = - if (noCombiner) { - Bagel.run(sc, vertices, messages)(numSplits = numSplits)(PRNoCombiner.compute(numVertices, epsilon)) - } else { - Bagel.run(sc, vertices, messages)(combiner = PRCombiner, numSplits = numSplits)(PRCombiner.compute(numVertices, epsilon)) - } + Bagel.run( + sc, vertices, messages, combiner = new PRCombiner(), + numSplits = numSplits)( + utils.computeWithCombiner(numVertices, epsilon)) // Print the result System.err.println("Articles with PageRank >= "+threshold+":") - val top = result.filter(_.value >= threshold).map(vertex => - "%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString + val top = + (result + .filter { case (id, vertex) => vertex.value >= threshold } + .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } + .collect.mkString) println(top) } } - -object PRCombiner extends Combiner[PRMessage, Double] with Serializable { - def createCombiner(msg: PRMessage): Double = - msg.value - def mergeMsg(combiner: Double, msg: PRMessage): Double = - combiner + msg.value - def mergeCombiners(a: Double, b: Double): Double = - a + b - - def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = { - val newValue = messageSum match { - case Some(msgSum) if msgSum != 0 => - 0.15 / numVertices + 0.85 * msgSum - case _ => self.value - } - - val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30 - - val outbox = - if (!terminate) - self.outEdges.map(edge => - new PRMessage(edge.targetId, newValue / self.outEdges.size)) - else - ArrayBuffer[PRMessage]() - - (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox) - } -} - -object PRNoCombiner extends DefaultCombiner[PRMessage] with Serializable { - def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) = - PRCombiner.compute(numVertices, epsilon)(self, messages match { - case Some(msgs) => Some(msgs.map(_.value).sum) - case None => None - }, superstep) -} - -class PRVertex() extends Vertex with Serializable { - var id: String = _ - var value: Double = _ - var outEdges: ArrayBuffer[PREdge] = _ - var active: Boolean = true - - def this(id: String, value: Double, outEdges: ArrayBuffer[PREdge], active: Boolean) { - this() - this.id = id - this.value = value - this.outEdges = outEdges - this.active = active - } -} - -class PRMessage() extends Message with Serializable { - var targetId: String = _ - var value: Double = _ - - def this(targetId: String, value: Double) { - this() - this.targetId = targetId - this.value = value - } -} - -class PREdge() extends Edge with Serializable { - var targetId: String = _ - - def this(targetId: String) { - this() - this.targetId = targetId - } -} - -class PRKryoRegistrator extends KryoRegistrator { - def registerClasses(kryo: Kryo) { - kryo.register(classOf[PRVertex]) - kryo.register(classOf[PRMessage]) - kryo.register(classOf[PREdge]) - } -} -- cgit v1.2.3