From 0028caf3a4727623f70e23cd2f611f9797d0a3d3 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Sun, 9 Oct 2011 15:58:39 -0700 Subject: Simplify and genericize type parameters in Bagel --- bagel/src/main/scala/spark/bagel/Bagel.scala | 214 ++++++++++++++++----------- 1 file changed, 129 insertions(+), 85 deletions(-) (limited to 'bagel/src') diff --git a/bagel/src/main/scala/spark/bagel/Bagel.scala b/bagel/src/main/scala/spark/bagel/Bagel.scala index c24c65be2a..2f57c9c0fd 100644 --- a/bagel/src/main/scala/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/spark/bagel/Bagel.scala @@ -6,54 +6,110 @@ import spark.SparkContext._ import scala.collection.mutable.ArrayBuffer object Bagel extends Logging { - def run[V <: Vertex : Manifest, M <: Message : Manifest, C : Manifest, A : Manifest]( + def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, + C : Manifest, A : Manifest]( sc: SparkContext, - verts: RDD[(String, V)], - msgs: RDD[(String, M)] + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + combiner: Combiner[M, C], + aggregator: Option[Aggregator[V, A]], + partitioner: Partitioner, + numSplits: Int )( - combiner: Combiner[M, C] = new DefaultCombiner[M], - aggregator: Aggregator[V, A] = new NullAggregator[V], - superstep: Int = 0, - numSplits: Int = sc.defaultParallelism - )( - compute: (V, Option[C], A, Int) => (V, Iterable[M]) - ): RDD[V] = { - - logInfo("Starting superstep "+superstep+".") - val startTime = System.currentTimeMillis - - val aggregated = agg(verts, aggregator) - val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits) - val grouped = verts.groupWith(combinedMsgs) - val (processed, numMsgs, numActiveVerts) = comp[V, M, C](sc, grouped, compute(_, _, aggregated, superstep)) - - val timeTaken = System.currentTimeMillis - startTime - logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) - - // Check stopping condition and iterate - val noActivity = numMsgs == 0 && numActiveVerts == 0 - if (noActivity) { - processed.map { case (id, (vert, msgs)) => vert } - } else { - val newVerts = processed.mapValues { case (vert, msgs) => vert } - val newMsgs = processed.flatMap { + compute: (V, Option[C], Option[A], Int) => (V, Array[M]) + ): RDD[(K, V)] = { + val splits = if (numSplits != 0) numSplits else sc.defaultParallelism + + var superstep = 0 + var verts = vertices + var msgs = messages + var noActivity = false + do { + logInfo("Starting superstep "+superstep+".") + val startTime = System.currentTimeMillis + + val aggregated = agg(verts, aggregator) + val combinedMsgs = msgs.combineByKey( + combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, + splits, partitioner) + val grouped = combinedMsgs.groupWith(verts) + val (processed, numMsgs, numActiveVerts) = + comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep)) + + val timeTaken = System.currentTimeMillis - startTime + logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) + + verts = processed.mapValues { case (vert, msgs) => vert } + msgs = processed.flatMap { case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) } - run(sc, newVerts, newMsgs)(combiner, aggregator, superstep + 1, numSplits)(compute) - } + superstep += 1 + + noActivity = numMsgs == 0 && numActiveVerts == 0 + } while (!noActivity) + + verts + } + + def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, + C : Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + combiner: Combiner[M, C], + partitioner: Partitioner, + numSplits: Int + )( + compute: (V, Option[C], Int) => (V, Array[M]) + ): RDD[(K, V)] = { + run[K, V, M, C, Nothing]( + sc, vertices, messages, combiner, None, partitioner, numSplits)( + addAggregatorArg[K, V, M, C](compute)) + } + + def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, + C : Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + combiner: Combiner[M, C], + numSplits: Int + )( + compute: (V, Option[C], Int) => (V, Array[M]) + ): RDD[(K, V)] = { + val part = new HashPartitioner(numSplits) + run[K, V, M, C, Nothing]( + sc, vertices, messages, combiner, None, part, numSplits)( + addAggregatorArg[K, V, M, C](compute)) + } + + def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + numSplits: Int + )( + compute: (V, Option[Array[M]], Int) => (V, Array[M]) + ): RDD[(K, V)] = { + val part = new HashPartitioner(numSplits) + run[K, V, M, Array[M], Nothing]( + sc, vertices, messages, new DefaultCombiner(), None, part, numSplits)( + addAggregatorArg[K, V, M, Array[M]](compute)) } /** - * Aggregates the given vertices using the given aggregator, or does - * nothing if it is a NullAggregator. + * Aggregates the given vertices using the given aggregator, if it + * is specified. */ - def agg[V <: Vertex, A : Manifest](verts: RDD[(String, V)], aggregator: Aggregator[V, A]): A = aggregator match { - case _: NullAggregator[_] => - None - case _ => - verts.map { - case (id, vert) => aggregator.createAggregator(vert) - }.reduce(aggregator.mergeAggregators(_, _)) + private def agg[K, V <: Vertex, A : Manifest]( + verts: RDD[(K, V)], + aggregator: Option[Aggregator[V, A]] + ): Option[A] = aggregator match { + case Some(a) => + Some(verts.map { + case (id, vert) => a.createAggregator(vert) + }.reduce(a.mergeAggregators(_, _))) + case None => None } /** @@ -61,23 +117,27 @@ object Bagel extends Logging { * function. Returns the processed RDD, the number of messages * created, and the number of active vertices. */ - def comp[V <: Vertex, M <: Message, C](sc: SparkContext, grouped: RDD[(String, (Seq[V], Seq[C]))], compute: (V, Option[C]) => (V, Iterable[M])): (RDD[(String, (V, Iterable[M]))], Int, Int) = { + private def comp[K : Manifest, V <: Vertex, M <: Message[K], C]( + sc: SparkContext, + grouped: RDD[(K, (Seq[C], Seq[V]))], + compute: (V, Option[C]) => (V, Array[M]) + ): (RDD[(K, (V, Array[M]))], Int, Int) = { var numMsgs = sc.accumulator(0) var numActiveVerts = sc.accumulator(0) val processed = grouped.flatMapValues { - case (Seq(), _) => None - case (Seq(v), c) => - val (newVert, newMsgs) = - compute(v, c match { - case Seq(comb) => Some(comb) - case Seq() => None - }) - - numMsgs += newMsgs.size - if (newVert.active) - numActiveVerts += 1 - - Some((newVert, newMsgs)) + case (_, vs) if vs.size == 0 => None + case (c, vs) => + val (newVert, newMsgs) = + compute(vs(0), c match { + case Seq(comb) => Some(comb) + case Seq() => None + }) + + numMsgs += newMsgs.size + if (newVert.active) + numActiveVerts += 1 + + Some((newVert, newMsgs)) }.cache // Force evaluation of processed RDD for accurate performance measurements @@ -90,16 +150,16 @@ object Bagel extends Logging { * Converts a compute function that doesn't take an aggregator to * one that does, so it can be passed to Bagel.run. */ - implicit def addAggregatorArg[ - V <: Vertex : Manifest, M <: Message : Manifest, C + private def addAggregatorArg[ + K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C ]( - compute: (V, Option[C], Int) => (V, Iterable[M]) - ): (V, Option[C], Option[Nothing], Int) => (V, Iterable[M]) = { - (vert: V, messages: Option[C], aggregator: Option[Nothing], superstep: Int) => compute(vert, messages, superstep) + compute: (V, Option[C], Int) => (V, Array[M]) + ): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = { + (vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) => + compute(vert, msgs, superstep) } } -// TODO: Simplify Combiner interface and make it more OO. trait Combiner[M, C] { def createCombiner(msg: M): C def mergeMsg(combiner: C, msg: M): C @@ -111,18 +171,13 @@ trait Aggregator[V, A] { def mergeAggregators(a: A, b: A): A } -class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] with Serializable { - def createCombiner(msg: M): ArrayBuffer[M] = - ArrayBuffer(msg) - def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] = - combiner += msg - def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] = - a ++= b -} - -class NullAggregator[V] extends Aggregator[V, Option[Nothing]] with Serializable { - def createAggregator(vert: V): Option[Nothing] = None - def mergeAggregators(a: Option[Nothing], b: Option[Nothing]): Option[Nothing] = None +class DefaultCombiner[M : Manifest] extends Combiner[M, Array[M]] with Serializable { + def createCombiner(msg: M): Array[M] = + Array(msg) + def mergeMsg(combiner: Array[M], msg: M): Array[M] = + combiner :+ msg + def mergeCombiners(a: Array[M], b: Array[M]): Array[M] = + a ++ b } /** @@ -132,7 +187,6 @@ class NullAggregator[V] extends Aggregator[V, Option[Nothing]] with Serializable * inherit from java.io.Serializable or scala.Serializable. */ trait Vertex { - def id: String def active: Boolean } @@ -142,16 +196,6 @@ trait Vertex { * Subclasses may contain a payload to deliver to the target vertex * and must inherit from java.io.Serializable or scala.Serializable. */ -trait Message { - def targetId: String -} - -/** - * Represents a directed edge between two vertices. - * - * Subclasses may store state along each edge and must inherit from - * java.io.Serializable or scala.Serializable. - */ -trait Edge { - def targetId: String +trait Message[K] { + def targetId: K } -- cgit v1.2.3