commit cbdc01eecd235d03bf04f8e63c5dfac7cd622134
tree 3a7a32223653fc6519c6fc51d19ba135758b140b
parent 6d707f6b63e875f1b88210da2cf486f9d33f83c0
author Ankur Dave <ankurdave@gmail.com> 2011-10-09 16:00:59 -0700
committer Ankur Dave <ankurdave@gmail.com> 2011-10-09 16:19:34 -0700
Update WikipediaPageRank to reflect Bagel API changes
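
The Bagel.run entry point now takes vertices as (id, vertex) pairs, the
combiner as an instance rather than a singleton object, and numSplits as a
named argument. A minimal sketch of the new call shape, using the names that
appear in this diff (setup of sc, vertices, messages, numVertices, and
epsilon is assumed):

    val utils = new PageRankUtils
    val result = Bagel.run(
      sc, vertices, messages, combiner = new PRCombiner(),
      numSplits = numSplits)(
      utils.computeWithCombiner(numVertices, epsilon))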
Diffstat (limited to 'bagel/src/main'):
 bagel/src/main/scala/spark/bagel/examples/PageRankUtils.scala     | 106
 bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala | 123
 2 files changed, 129 insertions(+), 100 deletions(-)
diff --git a/bagel/src/main/scala/spark/bagel/examples/PageRankUtils.scala b/bagel/src/main/scala/spark/bagel/examples/PageRankUtils.scala
new file mode 100644
index 0000000000..b97d786ed4
--- /dev/null
+++ b/bagel/src/main/scala/spark/bagel/examples/PageRankUtils.scala
@@ -0,0 +1,106 @@
+package spark.bagel.examples
+
+import spark._
+import spark.SparkContext._
+
+import spark.bagel._
+import spark.bagel.Bagel._
+
+import scala.collection.mutable.ArrayBuffer
+
+import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
+
+import com.esotericsoftware.kryo._
+
+class PageRankUtils extends Serializable {
+ def computeWithCombiner(numVertices: Long, epsilon: Double)(
+ self: PRVertex, messageSum: Option[Double], superstep: Int
+ ): (PRVertex, Array[PRMessage]) = {
+ val newValue = messageSum match {
+ case Some(msgSum) if msgSum != 0 =>
+ 0.15 / numVertices + 0.85 * msgSum
+ case _ => self.value
+ }
+
+ val terminate = superstep >= 10
+
+ val outbox: Array[PRMessage] =
+ if (!terminate)
+ self.outEdges.map(targetId =>
+ new PRMessage(targetId, newValue / self.outEdges.size))
+ else
+ Array[PRMessage]()
+
+ (new PRVertex(newValue, self.outEdges, !terminate), outbox)
+ }
+
+ def computeNoCombiner(numVertices: Long, epsilon: Double)(
+ self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int
+ ): (PRVertex, Array[PRMessage]) =
+ computeWithCombiner(numVertices, epsilon)(self, messages match {
+ case Some(msgs) => Some(msgs.map(_.value).sum)
+ case None => None
+ }, superstep)
+}
+
+class PRCombiner extends Combiner[PRMessage, Double] with Serializable {
+ def createCombiner(msg: PRMessage): Double =
+ msg.value
+ def mergeMsg(combiner: Double, msg: PRMessage): Double =
+ combiner + msg.value
+ def mergeCombiners(a: Double, b: Double): Double =
+ a + b
+}
+
+class PRVertex() extends Vertex with Serializable {
+ var value: Double = _
+ var outEdges: Array[String] = _
+ var active: Boolean = _
+
+ def this(value: Double, outEdges: Array[String], active: Boolean = true) {
+ this()
+ this.value = value
+ this.outEdges = outEdges
+ this.active = active
+ }
+
+ override def toString(): String = {
+ "PRVertex(value=%f, outEdges.length=%d, active=%s)".format(value, outEdges.length, active.toString)
+ }
+}
+
+class PRMessage() extends Message[String] with Serializable {
+ var targetId: String = _
+ var value: Double = _
+
+ def this(targetId: String, value: Double) {
+ this()
+ this.targetId = targetId
+ this.value = value
+ }
+}
+
+class PRKryoRegistrator extends KryoRegistrator {
+ def registerClasses(kryo: Kryo) {
+ kryo.register(classOf[PRVertex])
+ kryo.register(classOf[PRMessage])
+ }
+}
+
+class CustomPartitioner(partitions: Int) extends Partitioner {
+ def numPartitions = partitions
+
+ def getPartition(key: Any): Int = {
+ val hash = key match {
+ case k: Long => (k & 0x00000000FFFFFFFFL).toInt
+ case _ => key.hashCode
+ }
+
+ val mod = hash % partitions
+ if (mod < 0) mod + partitions else mod
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case c: CustomPartitioner =>
+ c.numPartitions == numPartitions
+ case _ => false
+ }
+}
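
A hypothetical use of the CustomPartitioner defined above (the
WikipediaPageRank driver below uses a HashPartitioner instead; numSplits
stands in for the desired partition count):

    val partitioned = vertices
      .partitionBy(new CustomPartitioner(numSplits))
      .cache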
diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala
index 9a0dbbe9d7..f37ee01fd2 100644
--- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala
+++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala
@@ -6,28 +6,23 @@ import spark.SparkContext._
import spark.bagel._
import spark.bagel.Bagel._
-import scala.collection.mutable.ArrayBuffer
import scala.xml.{XML,NodeSeq}
-import java.io.{Externalizable,ObjectInput,ObjectOutput,DataOutputStream,DataInputStream}
-
-import com.esotericsoftware.kryo._
-
object WikipediaPageRank {
def main(args: Array[String]) {
- if (args.length < 4) {
- System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numSplits> <host> [<noCombiner>]")
+ if (args.length < 5) {
+ System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numSplits> <host> <usePartitioner>")
System.exit(-1)
}
- System.setProperty("spark.serialization", "spark.KryoSerialization")
+ System.setProperty("spark.serializer", "spark.KryoSerializer")
System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName)
val inputFile = args(0)
val threshold = args(1).toDouble
val numSplits = args(2).toInt
val host = args(3)
- val noCombiner = args.length > 4 && args(4).nonEmpty
+ val usePartitioner = args(4).toBoolean
val sc = new SparkContext(host, "WikipediaPageRank")
// Parse the Wikipedia page data into a graph
@@ -38,7 +33,7 @@ object WikipediaPageRank {
println("Done counting vertices.")
println("Parsing input file...")
- val vertices: RDD[(String, PRVertex)] = input.map(line => {
+ var vertices = input.map(line => {
val fields = line.split("\t")
val (title, body) = (fields(1), fields(3).replace("\\n", "\n"))
val links =
@@ -52,105 +47,33 @@ object WikipediaPageRank {
System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body)
NodeSeq.Empty
}
- val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*)
+ val outEdges = links.map(link => new String(link.text)).toArray
val id = new String(title)
- (id, new PRVertex(id, 1.0 / numVertices, outEdges, true))
- }).cache
+ (id, new PRVertex(1.0 / numVertices, outEdges))
+ })
+ if (usePartitioner)
+ vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache
+ else
+ vertices = vertices.cache
println("Done parsing input file.")
// Do the computation
val epsilon = 0.01 / numVertices
- val messages = sc.parallelize(List[(String, PRMessage)]())
+ val messages = sc.parallelize(Array[(String, PRMessage)]())
+ val utils = new PageRankUtils
val result =
- if (noCombiner) {
- Bagel.run(sc, vertices, messages)(numSplits = numSplits)(PRNoCombiner.compute(numVertices, epsilon))
- } else {
- Bagel.run(sc, vertices, messages)(combiner = PRCombiner, numSplits = numSplits)(PRCombiner.compute(numVertices, epsilon))
- }
+ Bagel.run(
+ sc, vertices, messages, combiner = new PRCombiner(),
+ numSplits = numSplits)(
+ utils.computeWithCombiner(numVertices, epsilon))
// Print the result
System.err.println("Articles with PageRank >= "+threshold+":")
- val top = result.filter(_.value >= threshold).map(vertex =>
- "%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString
+ val top =
+ (result
+ .filter { case (id, vertex) => vertex.value >= threshold }
+ .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) }
+ .collect.mkString)
println(top)
}
}
-
-object PRCombiner extends Combiner[PRMessage, Double] with Serializable {
- def createCombiner(msg: PRMessage): Double =
- msg.value
- def mergeMsg(combiner: Double, msg: PRMessage): Double =
- combiner + msg.value
- def mergeCombiners(a: Double, b: Double): Double =
- a + b
-
- def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = {
- val newValue = messageSum match {
- case Some(msgSum) if msgSum != 0 =>
- 0.15 / numVertices + 0.85 * msgSum
- case _ => self.value
- }
-
- val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
-
- val outbox =
- if (!terminate)
- self.outEdges.map(edge =>
- new PRMessage(edge.targetId, newValue / self.outEdges.size))
- else
- ArrayBuffer[PRMessage]()
-
- (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
- }
-}
-
-object PRNoCombiner extends DefaultCombiner[PRMessage] with Serializable {
- def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
- PRCombiner.compute(numVertices, epsilon)(self, messages match {
- case Some(msgs) => Some(msgs.map(_.value).sum)
- case None => None
- }, superstep)
-}
-
-class PRVertex() extends Vertex with Serializable {
- var id: String = _
- var value: Double = _
- var outEdges: ArrayBuffer[PREdge] = _
- var active: Boolean = true
-
- def this(id: String, value: Double, outEdges: ArrayBuffer[PREdge], active: Boolean) {
- this()
- this.id = id
- this.value = value
- this.outEdges = outEdges
- this.active = active
- }
-}
-
-class PRMessage() extends Message with Serializable {
- var targetId: String = _
- var value: Double = _
-
- def this(targetId: String, value: Double) {
- this()
- this.targetId = targetId
- this.value = value
- }
-}
-
-class PREdge() extends Edge with Serializable {
- var targetId: String = _
-
- def this(targetId: String) {
- this()
- this.targetId = targetId
- }
-}
-
-class PRKryoRegistrator extends KryoRegistrator {
- def registerClasses(kryo: Kryo) {
- kryo.register(classOf[PRVertex])
- kryo.register(classOf[PRMessage])
- kryo.register(classOf[PREdge])
- }
-}
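
With this change the driver takes a mandatory fifth argument, usePartitioner,
matching the updated usage string. A hypothetical invocation through Spark's
run script of that era (input path and master URL are placeholders):

    ./run spark.bagel.examples.WikipediaPageRank \
      wiki-dump.txt 0.01 16 local[4] true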