aboutsummaryrefslogtreecommitdiff
path: root/bagel/src/main/scala/bagel/Pregel.scala
diff options
context:
space:
mode:
Diffstat (limited to 'bagel/src/main/scala/bagel/Pregel.scala')
-rw-r--r--bagel/src/main/scala/bagel/Pregel.scala109
1 files changed, 59 insertions, 50 deletions
diff --git a/bagel/src/main/scala/bagel/Pregel.scala b/bagel/src/main/scala/bagel/Pregel.scala
index b0f7a48b7a..5ef398d783 100644
--- a/bagel/src/main/scala/bagel/Pregel.scala
+++ b/bagel/src/main/scala/bagel/Pregel.scala
@@ -7,75 +7,81 @@ import scala.collection.mutable.ArrayBuffer
object Pregel extends Logging {
/**
- * Runs a Pregel job on the given vertices, running the specified
- * compute function on each vertex in every superstep. Before
- * beginning the first superstep, sends the given messages to their
- * destination vertices. In the join stage, launches splits
- * separate tasks (where splits is manually specified to work
- * around a bug in Spark).
+ * Runs a Pregel job on the given vertices, running the specified
+ * compute function on each vertex in every superstep.
*
- * Halts when no more messages are being sent between vertices, and
- * all vertices have voted to halt by setting their state to
- * Inactive.
+ * Before beginning the first superstep, the given messages are sent
+ * to their destination vertices.
+ *
+ * During the job, the specified combiner functions are applied to
+ * messages as they travel between vertices.
+ *
+ * The job halts and returns the resulting set of vertices when no
+ * messages are being sent between vertices and all vertices have
+ * voted to halt by setting their state to inactive.
*/
- def run[V <: Vertex : Manifest, M <: Message : Manifest, C](sc: SparkContext, verts: RDD[(String, V)], msgs: RDD[(String, M)], splits: Int, messageCombiner: (C, M) => C, defaultCombined: () => C, mergeCombined: (C, C) => C, maxSupersteps: Option[Int] = None, superstep: Int = 0)(compute: (V, C, Int) => (V, Iterable[M])): RDD[V] = {
+ def run[V <: Vertex : Manifest, M <: Message : Manifest, C](
+ sc: SparkContext,
+ verts: RDD[(String, V)],
+ msgs: RDD[(String, M)],
+ createCombiner: M => C,
+ mergeMsg: (C, M) => C,
+ mergeCombiners: (C, C) => C,
+ numSplits: Int,
+ superstep: Int = 0
+ )(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = {
+
logInfo("Starting superstep "+superstep+".")
val startTime = System.currentTimeMillis
// Bring together vertices and messages
- val combinedMsgs = msgs.combineByKey({x => messageCombiner(defaultCombined(), x)}, messageCombiner, mergeCombined, splits)
- logDebug("verts.splits.size = " + verts.splits.size)
- logDebug("combinedMsgs.splits.size = " + combinedMsgs.splits.size)
- logDebug("verts.partitioner = " + verts.partitioner)
- logDebug("combinedMsgs.partitioner = " + combinedMsgs.partitioner)
-
- val joined = verts.groupWith(combinedMsgs)
- logDebug("joined.splits.size = " + joined.splits.size)
- logDebug("joined.partitioner = " + joined.partitioner)
+ val combinedMsgs = msgs.combineByKey(createCombiner, mergeMsg, mergeCombiners, numSplits)
+ val grouped = verts.groupWith(combinedMsgs)
// Run compute on each vertex
- var messageCount = sc.accumulator(0)
- var activeVertexCount = sc.accumulator(0)
- val processed = joined.flatMapValues {
+ var numMsgs = sc.accumulator(0)
+ var numActiveVerts = sc.accumulator(0)
+ val processed = grouped.flatMapValues {
case (Seq(), _) => None
- case (Seq(v), Seq(comb)) =>
- val (newVertex, newMessages) = compute(v, comb, superstep)
+ case (Seq(v), c) =>
+ val (newVert, newMsgs) =
+ compute(v, c match {
+ case Seq(comb) => Some(comb)
+ case Seq() => None
+ }, superstep)
- messageCount += newMessages.size
- if (newVertex.active)
- activeVertexCount += 1
+ numMsgs += newMsgs.size
+ if (newVert.active)
+ numActiveVerts += 1
- Some((newVertex, newMessages))
- case (Seq(v), Seq()) =>
- val (newVertex, newMessages) = compute(v, defaultCombined(), superstep)
-
- messageCount += newMessages.size
- if (newVertex.active)
- activeVertexCount += 1
-
- Some((newVertex, newMessages))
+ Some((newVert, newMsgs))
}.cache
+
// Force evaluation of processed RDD for accurate performance measurements
processed.foreach(x => {})
val timeTaken = System.currentTimeMillis - startTime
logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
- // Check stopping condition and recurse
- val stop = messageCount.value == 0 && activeVertexCount.value == 0
- if (stop || (maxSupersteps.isDefined && superstep >= maxSupersteps.get)) {
- processed.map { _._2._1 }
+ // Check stopping condition and iterate
+ val noActivity = numMsgs.value == 0 && numActiveVerts.value == 0
+ if (noActivity) {
+ processed.map { case (id, (vert, msgs)) => vert }
} else {
- val newVerts = processed.mapValues(_._1)
- val newMsgs = processed.flatMap(x => x._2._2.map(m => (m.targetId, m)))
- run(sc, newVerts, newMsgs, splits, messageCombiner, defaultCombined, mergeCombined, maxSupersteps, superstep + 1)(compute)
+ val newVerts = processed.mapValues { case (vert, msgs) => vert }
+ val newMsgs = processed.flatMap {
+ case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
+ }
+ run(sc, newVerts, newMsgs, createCombiner, mergeMsg, mergeCombiners, numSplits, superstep + 1)(compute)
}
}
}
/**
- * Represents a Pregel vertex. Must be subclassed to store state
- * along with each vertex. Must be annotated with @serializable.
+ * Represents a Pregel vertex.
+ *
+ * Subclasses may store state along with each vertex and must be
+ * annotated with @serializable.
*/
trait Vertex {
def id: String
@@ -83,17 +89,20 @@ trait Vertex {
}
/**
- * Represents a Pregel message to a target vertex. Must be
- * subclassed to contain a payload. Must be annotated with @serializable.
+ * Represents a Pregel message to a target vertex.
+ *
+ * Subclasses may contain a payload to deliver to the target vertex
+ * and must be annotated with @serializable.
*/
trait Message {
def targetId: String
}
/**
- * Represents a directed edge between two vertices. Owned by the
- * source vertex, and contains the ID of the target vertex. Must
- * be subclassed to store state along with each edge. Must be annotated with @serializable.
+ * Represents a directed edge between two vertices.
+ *
+ * Subclasses may store state along with each edge and must be
+ * annotated with @serializable.
*/
trait Edge {
def targetId: String