aboutsummaryrefslogtreecommitdiff
path: root/bagel
diff options
context:
space:
mode:
authorAnkur Dave <ankurdave@gmail.com>2011-04-23 13:38:37 -0700
committerAnkur Dave <ankurdave@gmail.com>2011-05-03 15:40:45 -0700
commit563c5e717cc75869c328bba17116313eab9e976b (patch)
tree99c2894a1bc9d4317b11d7de26c98f514b879794 /bagel
parentc18fa3ebc6848d2da19ac2f68c9e22870e135ecd (diff)
downloadspark-563c5e717cc75869c328bba17116313eab9e976b.tar.gz
spark-563c5e717cc75869c328bba17116313eab9e976b.tar.bz2
spark-563c5e717cc75869c328bba17116313eab9e976b.zip
Refactor and add aggregator support
Refactored out the agg() and comp() methods from Pregel.run. Defined an implicit conversion to allow applications that don't use aggregators to avoid including a null argument for the result of the aggregator in the compute function.
Diffstat (limited to 'bagel')
-rw-r--r--bagel/src/main/scala/bagel/Pregel.scala111
-rw-r--r--bagel/src/main/scala/bagel/ShortestPath.scala5
-rw-r--r--bagel/src/main/scala/bagel/WikipediaPageRank.scala6
-rw-r--r--bagel/src/test/scala/bagel/BagelSuite.scala10
4 files changed, 88 insertions, 44 deletions
diff --git a/bagel/src/main/scala/bagel/Pregel.scala b/bagel/src/main/scala/bagel/Pregel.scala
index e3b6d0c70a..2ec94ebb0d 100644
--- a/bagel/src/main/scala/bagel/Pregel.scala
+++ b/bagel/src/main/scala/bagel/Pregel.scala
@@ -6,37 +6,62 @@ import spark.SparkContext._
import scala.collection.mutable.ArrayBuffer
object Pregel extends Logging {
- /**
- * Runs a Pregel job on the given vertices consisting of the
- * specified compute function.
- *
- * Before beginning the first superstep, the given messages are sent
- * to their destination vertices.
- *
- * During the job, the specified combiner functions are applied to
- * messages as they travel between vertices.
- *
- * The job halts and returns the resulting set of vertices when no
- * messages are being sent between vertices and all vertices have
- * voted to halt by setting their state to inactive.
- */
- def run[V <: Vertex : Manifest, M <: Message : Manifest, C](
+ def run[V <: Vertex : Manifest, M <: Message : Manifest, C : Manifest, A : Manifest](
sc: SparkContext,
verts: RDD[(String, V)],
- msgs: RDD[(String, M)],
- combiner: Combiner[M, C],
- numSplits: Int,
- superstep: Int = 0
- )(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = {
+ msgs: RDD[(String, M)]
+ )(
+ combiner: Combiner[M, C] = new DefaultCombiner[M],
+ aggregator: Aggregator[V, A] = new NullAggregator[V],
+ superstep: Int = 0,
+ numSplits: Int = sc.numCores
+ )(
+ compute: (V, Option[C], A, Int) => (V, Iterable[M])
+ ): RDD[V] = {
logInfo("Starting superstep "+superstep+".")
val startTime = System.currentTimeMillis
- // Bring together vertices and messages
+ val aggregated = agg(verts, aggregator)
val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits)
val grouped = verts.groupWith(combinedMsgs)
+ val (processed, numMsgs, numActiveVerts) = comp[V, M, C](sc, grouped, compute(_, _, aggregated, superstep))
+
+ val timeTaken = System.currentTimeMillis - startTime
+ logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
+
+ // Check stopping condition and iterate
+ val noActivity = numMsgs == 0 && numActiveVerts == 0
+ if (noActivity) {
+ processed.map { case (id, (vert, msgs)) => vert }
+ } else {
+ val newVerts = processed.mapValues { case (vert, msgs) => vert }
+ val newMsgs = processed.flatMap {
+ case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
+ }
+ run(sc, newVerts, newMsgs)(combiner, aggregator, superstep + 1, numSplits)(compute)
+ }
+ }
+
+ /**
+ * Aggregates the given vertices using the given aggregator, or does
+ * nothing if it is a NullAggregator.
+ */
+ def agg[V <: Vertex, A : Manifest](verts: RDD[(String, V)], aggregator: Aggregator[V, A]): A = aggregator match {
+ case _: NullAggregator[_] =>
+ None
+ case _ =>
+ verts.map {
+ case (id, vert) => aggregator.createAggregator(vert)
+ }.reduce(aggregator.mergeAggregators(_, _))
+ }
- // Run compute on each vertex
+ /**
+ * Processes the given vertex-message RDD using the compute
+ * function. Returns the processed RDD, the number of messages
+ * created, and the number of active vertices.
+ */
+ def comp[V <: Vertex, M <: Message, C](sc: SparkContext, grouped: RDD[(String, (Seq[V], Seq[C]))], compute: (V, Option[C]) => (V, Iterable[M])): (RDD[(String, (V, Iterable[M]))], Int, Int) = {
var numMsgs = sc.accumulator(0)
var numActiveVerts = sc.accumulator(0)
val processed = grouped.flatMapValues {
@@ -46,7 +71,7 @@ object Pregel extends Logging {
compute(v, c match {
case Seq(comb) => Some(comb)
case Seq() => None
- }, superstep)
+ })
numMsgs += newMsgs.size
if (newVert.active)
@@ -58,30 +83,36 @@ object Pregel extends Logging {
// Force evaluation of processed RDD for accurate performance measurements
processed.foreach(x => {})
- val timeTaken = System.currentTimeMillis - startTime
- logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
+ (processed, numMsgs.value, numActiveVerts.value)
+ }
- // Check stopping condition and iterate
- val noActivity = numMsgs.value == 0 && numActiveVerts.value == 0
- if (noActivity) {
- processed.map { case (id, (vert, msgs)) => vert }
- } else {
- val newVerts = processed.mapValues { case (vert, msgs) => vert }
- val newMsgs = processed.flatMap {
- case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
- }
- run(sc, newVerts, newMsgs, combiner, numSplits, superstep + 1)(compute)
- }
+ /**
+ * Converts a compute function that doesn't take an aggregator to
+ * one that does, so it can be passed to Pregel.run.
+ */
+ implicit def addAggregatorArg[
+ V <: Vertex : Manifest, M <: Message : Manifest, C
+ ](
+ compute: (V, Option[C], Int) => (V, Iterable[M])
+ ): (V, Option[C], Option[Nothing], Int) => (V, Iterable[M]) = {
+ (vert: V, messages: Option[C], aggregator: Option[Nothing], superstep: Int) => compute(vert, messages, superstep)
}
}
+// TODO: Simplify Combiner interface and make it more OO.
trait Combiner[M, C] {
def createCombiner(msg: M): C
def mergeMsg(combiner: C, msg: M): C
def mergeCombiners(a: C, b: C): C
}
-@serializable class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] {
+trait Aggregator[V, A] {
+ def createAggregator(vert: V): A
+ def mergeAggregators(a: A, b: A): A
+}
+
+@serializable
+class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] {
def createCombiner(msg: M): ArrayBuffer[M] =
ArrayBuffer(msg)
def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
@@ -90,6 +121,12 @@ trait Combiner[M, C] {
a ++= b
}
+@serializable
+class NullAggregator[V] extends Aggregator[V, Option[Nothing]] {
+ def createAggregator(vert: V): Option[Nothing] = None
+ def mergeAggregators(a: Option[Nothing], b: Option[Nothing]): Option[Nothing] = None
+}
+
/**
* Represents a Pregel vertex.
*
diff --git a/bagel/src/main/scala/bagel/ShortestPath.scala b/bagel/src/main/scala/bagel/ShortestPath.scala
index 3fd2f39334..8f4a881850 100644
--- a/bagel/src/main/scala/bagel/ShortestPath.scala
+++ b/bagel/src/main/scala/bagel/ShortestPath.scala
@@ -5,6 +5,8 @@ import spark.SparkContext._
import scala.math.min
+import bagel.Pregel._
+
object ShortestPath {
def main(args: Array[String]) {
if (args.length < 4) {
@@ -49,7 +51,7 @@ object ShortestPath {
messages.count()+" messages.")
// Do the computation
- val result = Pregel.run(sc, vertices, messages, MinCombiner, numSplits) {
+ val compute = addAggregatorArg {
(self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
val newValue = messageMinValue match {
case Some(minVal) => min(self.value, minVal)
@@ -65,6 +67,7 @@ object ShortestPath {
(new SPVertex(self.id, newValue, self.outEdges, false), outbox)
}
+ val result = Pregel.run(sc, vertices, messages)(combiner = MinCombiner, numSplits = numSplits)(compute)
// Print the result
System.err.println("Shortest path from "+startVertex+" to all vertices:")
diff --git a/bagel/src/main/scala/bagel/WikipediaPageRank.scala b/bagel/src/main/scala/bagel/WikipediaPageRank.scala
index 994cea8ec3..2fe77b4962 100644
--- a/bagel/src/main/scala/bagel/WikipediaPageRank.scala
+++ b/bagel/src/main/scala/bagel/WikipediaPageRank.scala
@@ -3,6 +3,8 @@ package bagel
import spark._
import spark.SparkContext._
+import bagel.Pregel._
+
import scala.collection.mutable.ArrayBuffer
import scala.xml.{XML,NodeSeq}
@@ -60,9 +62,9 @@ object WikipediaPageRank {
val messages = sc.parallelize(List[(String, PRMessage)]())
val result =
if (noCombiner) {
- Pregel.run(sc, vertices, messages, PRNoCombiner, numSplits)(PRNoCombiner.compute(numVertices, epsilon))
+ Pregel.run(sc, vertices, messages)(numSplits = numSplits)(PRNoCombiner.compute(numVertices, epsilon))
} else {
- Pregel.run(sc, vertices, messages, PRCombiner, numSplits)(PRCombiner.compute(numVertices, epsilon))
+ Pregel.run(sc, vertices, messages)(combiner = PRCombiner, numSplits = numSplits)(PRCombiner.compute(numVertices, epsilon))
}
// Print the result
diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala
index 29f5f0c358..53a93a6b80 100644
--- a/bagel/src/test/scala/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/bagel/BagelSuite.scala
@@ -10,6 +10,8 @@ import scala.collection.mutable.ArrayBuffer
import spark._
+import bagel.Pregel._
+
@serializable class TestVertex(val id: String, val active: Boolean, val age: Int) extends Vertex
@serializable class TestMessage(val targetId: String) extends Message
@@ -20,10 +22,10 @@ class BagelSuite extends FunSuite with Assertions {
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 5
val result =
- Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) {
+ Pregel.run(sc, verts, msgs)()(addAggregatorArg {
(self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
(new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
- }
+ })
for (vert <- result.collect)
assert(vert.age === numSupersteps)
}
@@ -34,7 +36,7 @@ class BagelSuite extends FunSuite with Assertions {
val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
val numSupersteps = 5
val result =
- Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) {
+ Pregel.run(sc, verts, msgs)()(addAggregatorArg {
(self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
val msgsOut =
msgs match {
@@ -44,7 +46,7 @@ class BagelSuite extends FunSuite with Assertions {
new ArrayBuffer[TestMessage]()
}
(new TestVertex(self.id, self.active, self.age + 1), msgsOut)
- }
+ })
for (vert <- result.collect)
assert(vert.age === numSupersteps)
}