From c5be7d2b2268e44e3eafb460d4bf0fb0badf9b22 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Tue, 8 Nov 2011 19:56:44 +0000 Subject: Update Bagel unit tests to reflect API change --- bagel/src/test/scala/bagel/BagelSuite.scala | 44 ++++++++++++++--------------- 1 file changed, 21 insertions(+), 23 deletions(-) (limited to 'bagel') diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala index 59356e09f0..0eda80af64 100644 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/bagel/BagelSuite.scala @@ -10,45 +10,43 @@ import scala.collection.mutable.ArrayBuffer import spark._ -import spark.bagel.Bagel._ - -class TestVertex(val id: String, val active: Boolean, val age: Int) extends Vertex with Serializable -class TestMessage(val targetId: String) extends Message with Serializable +class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable +class TestMessage(val targetId: String) extends Message[String] with Serializable class BagelSuite extends FunSuite with Assertions { test("halting by voting") { val sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(id, true, 0)))) + val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) val msgs = sc.parallelize(Array[(String, TestMessage)]()) val numSupersteps = 5 val result = - Bagel.run(sc, verts, msgs)()(addAggregatorArg { - (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) => - (new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - }) - for (vert <- result.collect) + Bagel.run(sc, verts, msgs, sc.defaultParallelism) { + (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => + (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) + } + for ((id, vert) <- result.collect) assert(vert.age === numSupersteps) sc.stop() } test("halting by message silence") { val sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(id, false, 0)))) + val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) val numSupersteps = 5 val result = - Bagel.run(sc, verts, msgs)()(addAggregatorArg { - (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) => - val msgsOut = - msgs match { - case Some(ms) if (superstep < numSupersteps - 1) => - ms - case _ => - new ArrayBuffer[TestMessage]() - } - (new TestVertex(self.id, self.active, self.age + 1), msgsOut) - }) - for (vert <- result.collect) + Bagel.run(sc, verts, msgs, sc.defaultParallelism) { + (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => + val msgsOut = + msgs match { + case Some(ms) if (superstep < numSupersteps - 1) => + ms + case _ => + Array[TestMessage]() + } + (new TestVertex(self.active, self.age + 1), msgsOut) + } + for ((id, vert) <- result.collect) assert(vert.age === numSupersteps) sc.stop() } -- cgit v1.2.3