author    Ankur Dave <ankurdave@gmail.com>    2011-10-09 16:15:15 -0700
committer Ankur Dave <ankurdave@gmail.com>    2011-10-09 16:53:10 -0700
commit    ab3889f62753beb51a550b500cd0830adb03d4cc (patch)
tree      d98d26bc9d66af6bf35d2550cc820769c07fa667 /bagel
parent    cbdc01eecd235d03bf04f8e63c5dfac7cd622134 (diff)
Implement standalone WikipediaPageRank with custom serializer
Diffstat (limited to 'bagel')
-rw-r--r--  bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala  198
1 file changed, 198 insertions, 0 deletions
diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
new file mode 100644
index 0000000000..2e38376499
--- /dev/null
+++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
@@ -0,0 +1,198 @@
+package spark.bagel.examples
+
+import spark._
+import spark.SparkContext._
+
+import spark.bagel._
+import spark.bagel.Bagel._
+
+import scala.xml.{XML,NodeSeq}
+
+import scala.collection.mutable.ArrayBuffer
+
+import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
+
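+// PageRank over a Wikipedia link dump, implemented directly on RDDs
+// ("standalone", i.e. without the Bagel framework) and shipping shuffle
+// data through a purpose-built serializer (WPRSerializer, below).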
+object WikipediaPageRankStandalone {
+ def main(args: Array[String]) {
+ if (args.length < 5) {
+ System.err.println("Usage: WikipediaPageRankStandalone <inputFile> <threshold> <numIterations> <host> <usePartitioner>")
+ System.exit(-1)
+ }
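+
+    // Example invocation (input path and host here are hypothetical):
+    //   WikipediaPageRankStandalone wikipedia-articles.tsv 0.01 10 local[2] true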
+
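+    // Route shuffled records through the custom serializer defined at the
+    // bottom of this file.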
+ System.setProperty("spark.serializer", "spark.bagel.examples.WPRSerializer")
+
+ val inputFile = args(0)
+ val threshold = args(1).toDouble
+ val numIterations = args(2).toInt
+ val host = args(3)
+ val usePartitioner = args(4).toBoolean
+ val sc = new SparkContext(host, "WikipediaPageRankStandalone")
+
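+    // Parse the dump into (title, outgoing link targets) pairs; optionally
+    // hash-partition them so each iteration's cogroup with the ranks RDD
+    // avoids re-shuffling the (large, static) link table.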
+ val input = sc.textFile(inputFile)
+ val partitioner = new HashPartitioner(sc.defaultParallelism)
+ val links =
+ if (usePartitioner)
+ input.map(parseArticle _).partitionBy(partitioner).cache
+ else
+ input.map(parseArticle _).cache
+ val n = links.count
+ val defaultRank = 1.0 / n
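+    // Random-jump probability; 1 - a = 0.85 is the usual PageRank damping factor.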
+ val a = 0.15
+
+ // Do the computation
+ val startTime = System.currentTimeMillis
+ val ranks =
+ pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, sc.defaultParallelism)
+
+ // Print the result
+ System.err.println("Articles with PageRank >= "+threshold+":")
+ val top =
+ (ranks
+ .filter { case (id, rank) => rank >= threshold }
+ .map { case (id, rank) => "%s\t%s\n".format(id, rank) }
+ .collect.mkString)
+ println(top)
+
+ val time = (System.currentTimeMillis - startTime) / 1000.0
+ println("Completed %d iterations in %f seconds: %f seconds per iteration"
+ .format(numIterations, time, time / numIterations))
+ System.exit(0)
+ }
+
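+  // Parses one tab-separated dump line into (article title, link targets),
+  // treating a body of "\N" as empty and skipping articles whose body is
+  // malformed XML.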
+ def parseArticle(line: String): (String, Array[String]) = {
+ val fields = line.split("\t")
+ val (title, body) = (fields(1), fields(3).replace("\\n", "\n"))
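+    // new String(...) here and below likely forces a compact copy: before
+    // Java 7u6, substring/split results shared the whole line's char array.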
+ val id = new String(title)
+ val links =
+ if (body == "\\N")
+ NodeSeq.Empty
+ else
+ try {
+ XML.loadString(body) \\ "link" \ "target"
+ } catch {
+ case e: org.xml.sax.SAXParseException =>
+ System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body)
+ NodeSeq.Empty
+ }
+ val outEdges = links.map(link => new String(link.text)).toArray
+ (id, outEdges)
+ }
+
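+  // Iterative PageRank: each round cogroups links with the current ranks,
+  // spreads each page's rank evenly over its out-links, then sums incoming
+  // contributions and applies the damping update a/n + (1-a)*sum.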
+ def pageRank(
+ links: RDD[(String, Array[String])],
+ numIterations: Int,
+ defaultRank: Double,
+ a: Double,
+ n: Long,
+ partitioner: Partitioner,
+ usePartitioner: Boolean,
+ numSplits: Int
+ ): RDD[(String, Double)] = {
+ var ranks = links.mapValues { edges => defaultRank }
+ for (i <- 1 to numIterations) {
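+      // groupWith (cogroup) yields per-key sequence wrappers: a page with
+      // links but no current rank contributes defaultRank, and a page with
+      // no links contributes nothing.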
+ val contribs = links.groupWith(ranks).flatMap {
+ case (id, (linksWrapper, rankWrapper)) =>
+ if (linksWrapper.length > 0) {
+ if (rankWrapper.length > 0) {
+ linksWrapper(0).map(dest => (dest, rankWrapper(0) / linksWrapper(0).size))
+ } else {
+ linksWrapper(0).map(dest => (dest, defaultRank / linksWrapper(0).size))
+ }
+ } else {
+ Array[(String, Double)]()
+ }
+ }
+ ranks = (contribs.combineByKey((x: Double) => x,
+ (x: Double, y: Double) => x + y,
+ (x: Double, y: Double) => x + y,
+ numSplits,
+ partitioner)
+ .mapValues(sum => a/n + (1-a)*sum))
+ }
+ ranks
+ }
+}
+
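+// Custom Serializer for the record shapes this job shuffles: (id, links) and
+// (id, rank) cogroup wrappers plus plain (id, rank) pairs. Only streaming
+// (de)serialization is implemented; the byte-array forms are unsupported.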
+class WPRSerializer extends spark.Serializer {
+ def newInstance(): SerializerInstance = new WPRSerializerInstance()
+}
+
+class WPRSerializerInstance extends SerializerInstance {
+ def serialize[T](t: T): Array[Byte] = {
+ throw new UnsupportedOperationException()
+ }
+
+ def deserialize[T](bytes: Array[Byte]): T = {
+ throw new UnsupportedOperationException()
+ }
+
+ def outputStream(s: OutputStream): SerializationStream = {
+ new WPRSerializationStream(s)
+ }
+
+ def inputStream(s: InputStream): DeserializationStream = {
+ new WPRDeserializationStream(s)
+ }
+}
+
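+// Writes each record prefixed with an Int tag: 0 = wrapped (id, links),
+// 1 = wrapped (id, rank), 2 = plain (id, rank).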
+class WPRSerializationStream(os: OutputStream) extends SerializationStream {
+ val dos = new DataOutputStream(os)
+
+ def writeObject[T](t: T): Unit = t match {
+ case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match {
+ case links: Array[String] => {
+ dos.writeInt(0) // links
+ dos.writeUTF(id)
+ dos.writeInt(links.length)
+ for (link <- links) {
+ dos.writeUTF(link)
+ }
+ }
+ case rank: Double => {
+ dos.writeInt(1) // rank
+ dos.writeUTF(id)
+ dos.writeDouble(rank)
+ }
+ }
+ case (id: String, rank: Double) => {
+ dos.writeInt(2) // rank without wrapper
+ dos.writeUTF(id)
+ dos.writeDouble(rank)
+ }
+ }
+
+ def flush() { dos.flush() }
+ def close() { dos.close() }
+}
+
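+// Reads the tag written by WPRSerializationStream and rebuilds the matching
+// tuple, rewrapping cogrouped values in an ArrayBuffer as groupWith expects.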
+class WPRDeserializationStream(is: InputStream) extends DeserializationStream {
+ val dis = new DataInputStream(is)
+
+ def readObject[T](): T = {
+ val typeId = dis.readInt()
+ typeId match {
+ case 0 => {
+ val id = dis.readUTF()
+ val numLinks = dis.readInt()
+ val links = new Array[String](numLinks)
+ for (i <- 0 until numLinks) {
+ val link = dis.readUTF()
+ links(i) = link
+ }
+ (id, ArrayBuffer(links)).asInstanceOf[T]
+ }
+ case 1 => {
+ val id = dis.readUTF()
+ val rank = dis.readDouble()
+ (id, ArrayBuffer(rank)).asInstanceOf[T]
+ }
+ case 2 => {
+ val id = dis.readUTF()
+ val rank = dis.readDouble()
+ (id, rank).asInstanceOf[T]
+ }
+ }
+ }
+
+ def close() { dis.close() }
+}