From c18fa3ebc6848d2da19ac2f68c9e22870e135ecd Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Fri, 15 Apr 2011 21:40:54 -0700 Subject: Package combiner functions into a trait --- bagel/src/main/scala/bagel/Pregel.scala | 27 ++++---- bagel/src/main/scala/bagel/ShortestPath.scala | 16 +++-- bagel/src/main/scala/bagel/WikipediaPageRank.scala | 73 ++++++++++------------ bagel/src/test/scala/bagel/BagelSuite.scala | 10 +-- 4 files changed, 60 insertions(+), 66 deletions(-) diff --git a/bagel/src/main/scala/bagel/Pregel.scala b/bagel/src/main/scala/bagel/Pregel.scala index 67bc582fd1..e3b6d0c70a 100644 --- a/bagel/src/main/scala/bagel/Pregel.scala +++ b/bagel/src/main/scala/bagel/Pregel.scala @@ -2,7 +2,7 @@ package bagel import spark._ import spark.SparkContext._ -import scala.collection.mutable.HashMap + import scala.collection.mutable.ArrayBuffer object Pregel extends Logging { @@ -24,9 +24,7 @@ object Pregel extends Logging { sc: SparkContext, verts: RDD[(String, V)], msgs: RDD[(String, M)], - createCombiner: M => C, - mergeMsg: (C, M) => C, - mergeCombiners: (C, C) => C, + combiner: Combiner[M, C], numSplits: Int, superstep: Int = 0 )(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = { @@ -35,7 +33,7 @@ object Pregel extends Logging { val startTime = System.currentTimeMillis // Bring together vertices and messages - val combinedMsgs = msgs.combineByKey(createCombiner, mergeMsg, mergeCombiners, numSplits) + val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits) val grouped = verts.groupWith(combinedMsgs) // Run compute on each vertex @@ -72,17 +70,24 @@ object Pregel extends Logging { val newMsgs = processed.flatMap { case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) } - run(sc, newVerts, newMsgs, createCombiner, mergeMsg, mergeCombiners, numSplits, superstep + 1)(compute) + run(sc, newVerts, newMsgs, combiner, numSplits, superstep + 1)(compute) } } +} + +trait Combiner[M, C] { + def createCombiner(msg: M): C + def mergeMsg(combiner: C, msg: M): C + def mergeCombiners(a: C, b: C): C +} - def defaultCreateCombiner[M <: Message](msg: M): ArrayBuffer[M] = ArrayBuffer(msg) - def defaultMergeMsg[M <: Message](combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] = +@serializable class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] { + def createCombiner(msg: M): ArrayBuffer[M] = + ArrayBuffer(msg) + def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] = combiner += msg - def defaultMergeCombiners[M <: Message](a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] = + def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] = a ++= b - def defaultCompute[V <: Vertex, M <: Message](self: V, msgs: Option[ArrayBuffer[M]], superstep: Int): (V, Iterable[M]) = - (self, List()) } /** diff --git a/bagel/src/main/scala/bagel/ShortestPath.scala b/bagel/src/main/scala/bagel/ShortestPath.scala index 6699f58a31..3fd2f39334 100644 --- a/bagel/src/main/scala/bagel/ShortestPath.scala +++ b/bagel/src/main/scala/bagel/ShortestPath.scala @@ -49,12 +49,7 @@ object ShortestPath { messages.count()+" messages.") // Do the computation - def createCombiner(message: SPMessage): Int = message.value - def mergeMsg(combiner: Int, message: SPMessage): Int = - min(combiner, message.value) - def mergeCombiners(a: Int, b: Int): Int = min(a, b) - - val result = Pregel.run(sc, vertices, messages, createCombiner, mergeMsg, mergeCombiners, numSplits) { + val result = Pregel.run(sc, vertices, messages, MinCombiner, numSplits) { (self: SPVertex, messageMinValue: Option[Int], superstep: Int) => val newValue = messageMinValue match { case Some(minVal) => min(self.value, minVal) @@ -82,6 +77,15 @@ object ShortestPath { } } +object MinCombiner extends Combiner[SPMessage, Int] { + def createCombiner(msg: SPMessage): Int = + msg.value + def mergeMsg(combiner: Int, msg: SPMessage): Int = + min(combiner, msg.value) + def mergeCombiners(a: Int, b: Int): Int = + min(a, b) +} + @serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex @serializable class SPEdge(val targetId: String, val value: Int) extends Edge @serializable class SPMessage(val targetId: String, val value: Int) extends Message diff --git a/bagel/src/main/scala/bagel/WikipediaPageRank.scala b/bagel/src/main/scala/bagel/WikipediaPageRank.scala index f6aeff0782..994cea8ec3 100644 --- a/bagel/src/main/scala/bagel/WikipediaPageRank.scala +++ b/bagel/src/main/scala/bagel/WikipediaPageRank.scala @@ -60,9 +60,9 @@ object WikipediaPageRank { val messages = sc.parallelize(List[(String, PRMessage)]()) val result = if (noCombiner) { - Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, vertices, messages, NoCombiner.createCombiner, NoCombiner.mergeMsg, NoCombiner.mergeCombiners, numSplits)(NoCombiner.compute(numVertices, epsilon)) + Pregel.run(sc, vertices, messages, PRNoCombiner, numSplits)(PRNoCombiner.compute(numVertices, epsilon)) } else { - Pregel.run[PRVertex, PRMessage, Double](sc, vertices, messages, Combiner.createCombiner, Combiner.mergeMsg, Combiner.mergeCombiners, numSplits)(Combiner.compute(numVertices, epsilon)) + Pregel.run(sc, vertices, messages, PRCombiner, numSplits)(PRCombiner.compute(numVertices, epsilon)) } // Print the result @@ -71,53 +71,44 @@ object WikipediaPageRank { "%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString println(top) } +} - object Combiner { - def createCombiner(message: PRMessage): Double = message.value - - def mergeMsg(combiner: Double, message: PRMessage): Double = - combiner + message.value - - def mergeCombiners(a: Double, b: Double) = a + b - - def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = { - val newValue = messageSum match { - case Some(msgSum) if msgSum != 0 => - 0.15 / numVertices + 0.85 * msgSum - case _ => self.value - } - - val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30 - - val outbox = - if (!terminate) - self.outEdges.map(edge => - new PRMessage(edge.targetId, newValue / self.outEdges.size)) - else - ArrayBuffer[PRMessage]() - - (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox) +object PRCombiner extends Combiner[PRMessage, Double] { + def createCombiner(msg: PRMessage): Double = + msg.value + def mergeMsg(combiner: Double, msg: PRMessage): Double = + combiner + msg.value + def mergeCombiners(a: Double, b: Double): Double = + a + b + + def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = { + val newValue = messageSum match { + case Some(msgSum) if msgSum != 0 => + 0.15 / numVertices + 0.85 * msgSum + case _ => self.value } - } - object NoCombiner { - def createCombiner(message: PRMessage): ArrayBuffer[PRMessage] = - ArrayBuffer(message) + val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30 - def mergeMsg(combiner: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] = - combiner += message + val outbox = + if (!terminate) + self.outEdges.map(edge => + new PRMessage(edge.targetId, newValue / self.outEdges.size)) + else + ArrayBuffer[PRMessage]() - def mergeCombiners(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] = - a ++= b - - def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) = - Combiner.compute(numVertices, epsilon)(self, messages match { - case Some(msgs) => Some(msgs.map(_.value).sum) - case None => None - }, superstep) + (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox) } } +object PRNoCombiner extends DefaultCombiner[PRMessage] { + def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) = + PRCombiner.compute(numVertices, epsilon)(self, messages match { + case Some(msgs) => Some(msgs.map(_.value).sum) + case None => None + }, superstep) +} + @serializable class PRVertex() extends Vertex { var id: String = _ var value: Double = _ diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala index 72aecb7fd8..29f5f0c358 100644 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/bagel/BagelSuite.scala @@ -20,10 +20,7 @@ class BagelSuite extends FunSuite with Assertions { val msgs = sc.parallelize(Array[(String, TestMessage)]()) val numSupersteps = 5 val result = - Pregel.run(sc, verts, msgs, - Pregel.defaultCreateCombiner[TestMessage], - Pregel.defaultMergeMsg[TestMessage], - Pregel.defaultMergeCombiners[TestMessage], 1) { + Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) { (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) => (new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) } @@ -37,10 +34,7 @@ class BagelSuite extends FunSuite with Assertions { val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) val numSupersteps = 5 val result = - Pregel.run(sc, verts, msgs, - Pregel.defaultCreateCombiner[TestMessage], - Pregel.defaultMergeMsg[TestMessage], - Pregel.defaultMergeCombiners[TestMessage], 1) { + Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) { (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) => val msgsOut = msgs match { -- cgit v1.2.3