about | summary | refs | log | tree | commit | diff
path: root/bagel/src/main/scala
diff options
context:
space:
mode:
author Ankur Dave <ankurdave@gmail.com> 2011-04-15 21:40:54 -0700
committer Ankur Dave <ankurdave@gmail.com> 2011-05-03 15:40:41 -0700
commit c18fa3ebc6848d2da19ac2f68c9e22870e135ecd (patch)
tree cd3987a5ad9d6bba8362f7b12ee017db81c81c68 /bagel/src/main/scala
parent 1c8ca0ebe1537c8f424722294794a66ff123f132 (diff)
download spark-c18fa3ebc6848d2da19ac2f68c9e22870e135ecd.tar.gz
spark-c18fa3ebc6848d2da19ac2f68c9e22870e135ecd.tar.bz2
spark-c18fa3ebc6848d2da19ac2f68c9e22870e135ecd.zip
Package combiner functions into a trait
Diffstat (limited to 'bagel/src/main/scala')
-rw-r--r-- bagel/src/main/scala/bagel/Pregel.scala 27
-rw-r--r-- bagel/src/main/scala/bagel/ShortestPath.scala 16
-rw-r--r-- bagel/src/main/scala/bagel/WikipediaPageRank.scala 73
3 files changed, 58 insertions, 58 deletions
diff --git a/bagel/src/main/scala/bagel/Pregel.scala b/bagel/src/main/scala/bagel/Pregel.scala
index 67bc582fd1..e3b6d0c70a 100644
--- a/bagel/src/main/scala/bagel/Pregel.scala
+++ b/bagel/src/main/scala/bagel/Pregel.scala
@@ -2,7 +2,7 @@ package bagel
import spark._
import spark.SparkContext._
-import scala.collection.mutable.HashMap
+
import scala.collection.mutable.ArrayBuffer
object Pregel extends Logging {
@@ -24,9 +24,7 @@ object Pregel extends Logging {
sc: SparkContext,
verts: RDD[(String, V)],
msgs: RDD[(String, M)],
- createCombiner: M => C,
- mergeMsg: (C, M) => C,
- mergeCombiners: (C, C) => C,
+ combiner: Combiner[M, C],
numSplits: Int,
superstep: Int = 0
)(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = {
@@ -35,7 +33,7 @@ object Pregel extends Logging {
val startTime = System.currentTimeMillis
// Bring together vertices and messages
- val combinedMsgs = msgs.combineByKey(createCombiner, mergeMsg, mergeCombiners, numSplits)
+ val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits)
val grouped = verts.groupWith(combinedMsgs)
// Run compute on each vertex
@@ -72,17 +70,24 @@ object Pregel extends Logging {
val newMsgs = processed.flatMap {
case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
}
- run(sc, newVerts, newMsgs, createCombiner, mergeMsg, mergeCombiners, numSplits, superstep + 1)(compute)
+ run(sc, newVerts, newMsgs, combiner, numSplits, superstep + 1)(compute)
}
}
+}
+
+trait Combiner[M, C] {
+ def createCombiner(msg: M): C
+ def mergeMsg(combiner: C, msg: M): C
+ def mergeCombiners(a: C, b: C): C
+}
- def defaultCreateCombiner[M <: Message](msg: M): ArrayBuffer[M] = ArrayBuffer(msg)
- def defaultMergeMsg[M <: Message](combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
+@serializable class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] {
+ def createCombiner(msg: M): ArrayBuffer[M] =
+ ArrayBuffer(msg)
+ def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
combiner += msg
- def defaultMergeCombiners[M <: Message](a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] =
+ def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] =
a ++= b
- def defaultCompute[V <: Vertex, M <: Message](self: V, msgs: Option[ArrayBuffer[M]], superstep: Int): (V, Iterable[M]) =
- (self, List())
}
/**
diff --git a/bagel/src/main/scala/bagel/ShortestPath.scala b/bagel/src/main/scala/bagel/ShortestPath.scala
index 6699f58a31..3fd2f39334 100644
--- a/bagel/src/main/scala/bagel/ShortestPath.scala
+++ b/bagel/src/main/scala/bagel/ShortestPath.scala
@@ -49,12 +49,7 @@ object ShortestPath {
messages.count()+" messages.")
// Do the computation
- def createCombiner(message: SPMessage): Int = message.value
- def mergeMsg(combiner: Int, message: SPMessage): Int =
- min(combiner, message.value)
- def mergeCombiners(a: Int, b: Int): Int = min(a, b)
-
- val result = Pregel.run(sc, vertices, messages, createCombiner, mergeMsg, mergeCombiners, numSplits) {
+ val result = Pregel.run(sc, vertices, messages, MinCombiner, numSplits) {
(self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
val newValue = messageMinValue match {
case Some(minVal) => min(self.value, minVal)
@@ -82,6 +77,15 @@ object ShortestPath {
}
}
+object MinCombiner extends Combiner[SPMessage, Int] {
+ def createCombiner(msg: SPMessage): Int =
+ msg.value
+ def mergeMsg(combiner: Int, msg: SPMessage): Int =
+ min(combiner, msg.value)
+ def mergeCombiners(a: Int, b: Int): Int =
+ min(a, b)
+}
+
@serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex
@serializable class SPEdge(val targetId: String, val value: Int) extends Edge
@serializable class SPMessage(val targetId: String, val value: Int) extends Message
diff --git a/bagel/src/main/scala/bagel/WikipediaPageRank.scala b/bagel/src/main/scala/bagel/WikipediaPageRank.scala
index f6aeff0782..994cea8ec3 100644
--- a/bagel/src/main/scala/bagel/WikipediaPageRank.scala
+++ b/bagel/src/main/scala/bagel/WikipediaPageRank.scala
@@ -60,9 +60,9 @@ object WikipediaPageRank {
val messages = sc.parallelize(List[(String, PRMessage)]())
val result =
if (noCombiner) {
- Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, vertices, messages, NoCombiner.createCombiner, NoCombiner.mergeMsg, NoCombiner.mergeCombiners, numSplits)(NoCombiner.compute(numVertices, epsilon))
+ Pregel.run(sc, vertices, messages, PRNoCombiner, numSplits)(PRNoCombiner.compute(numVertices, epsilon))
} else {
- Pregel.run[PRVertex, PRMessage, Double](sc, vertices, messages, Combiner.createCombiner, Combiner.mergeMsg, Combiner.mergeCombiners, numSplits)(Combiner.compute(numVertices, epsilon))
+ Pregel.run(sc, vertices, messages, PRCombiner, numSplits)(PRCombiner.compute(numVertices, epsilon))
}
// Print the result
@@ -71,53 +71,44 @@ object WikipediaPageRank {
"%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString
println(top)
}
+}
- object Combiner {
- def createCombiner(message: PRMessage): Double = message.value
-
- def mergeMsg(combiner: Double, message: PRMessage): Double =
- combiner + message.value
-
- def mergeCombiners(a: Double, b: Double) = a + b
-
- def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = {
- val newValue = messageSum match {
- case Some(msgSum) if msgSum != 0 =>
- 0.15 / numVertices + 0.85 * msgSum
- case _ => self.value
- }
-
- val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
-
- val outbox =
- if (!terminate)
- self.outEdges.map(edge =>
- new PRMessage(edge.targetId, newValue / self.outEdges.size))
- else
- ArrayBuffer[PRMessage]()
-
- (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
+object PRCombiner extends Combiner[PRMessage, Double] {
+ def createCombiner(msg: PRMessage): Double =
+ msg.value
+ def mergeMsg(combiner: Double, msg: PRMessage): Double =
+ combiner + msg.value
+ def mergeCombiners(a: Double, b: Double): Double =
+ a + b
+
+ def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = {
+ val newValue = messageSum match {
+ case Some(msgSum) if msgSum != 0 =>
+ 0.15 / numVertices + 0.85 * msgSum
+ case _ => self.value
}
- }
- object NoCombiner {
- def createCombiner(message: PRMessage): ArrayBuffer[PRMessage] =
- ArrayBuffer(message)
+ val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
- def mergeMsg(combiner: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] =
- combiner += message
+ val outbox =
+ if (!terminate)
+ self.outEdges.map(edge =>
+ new PRMessage(edge.targetId, newValue / self.outEdges.size))
+ else
+ ArrayBuffer[PRMessage]()
- def mergeCombiners(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] =
- a ++= b
-
- def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
- Combiner.compute(numVertices, epsilon)(self, messages match {
- case Some(msgs) => Some(msgs.map(_.value).sum)
- case None => None
- }, superstep)
+ (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
}
}
+object PRNoCombiner extends DefaultCombiner[PRMessage] {
+ def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
+ PRCombiner.compute(numVertices, epsilon)(self, messages match {
+ case Some(msgs) => Some(msgs.map(_.value).sum)
+ case None => None
+ }, superstep)
+}
+
@serializable class PRVertex() extends Vertex {
var id: String = _
var value: Double = _