1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
|
package bagel
import spark._
import spark.SparkContext._
import scala.math.min
import bagel.Pregel._
/**
 * Command-line driver that runs a shortest-path computation over a graph
 * file using Pregel on Spark and prints the resulting distance per vertex.
 *
 * Input format (tab-separated; lines matching `^\s*#.*` are skipped):
 *  - three fields `source  target  weight` describe an outgoing edge;
 *  - two fields `vertexId  value` describe an initial message for a vertex.
 *
 * A distance of `Int.MaxValue` stands for "not reached" and is printed as
 * "inf".
 */
object ShortestPath {
  /**
   * Entry point.
   *
   * @param args `graphFile startVertex numSplits host`. `startVertex` is only
   *             echoed in the output header; the starting distances come from
   *             the two-field message lines of the graph file. `numSplits`
   *             must parse as an integer.
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 4) {
      System.err.println("Usage: ShortestPath <graphFile> <startVertex> " +
        "<numSplits> <host>")
      System.exit(-1)
    }

    val graphFile = args(0)
    val startVertex = args(1)
    val numSplits = args(2).toInt
    val host = args(3)
    val sc = new SparkContext(host, "ShortestPath")

    // Parse the graph data from a file into two RDDs, vertices and messages.
    // Comment lines are dropped; every remaining line is split on tabs.
    val lines =
      (sc.textFile(graphFile)
        .filter(!_.matches("^\\s*#.*"))
        .map(line => line.split("\t")))

    // One vertex per distinct first field. Only the three-field lines of a
    // group contribute edges; two-field (message) lines still create the
    // vertex but add no edge. Every vertex starts at "unreached" and active.
    val vertices: RDD[(String, SPVertex)] =
      (lines.groupBy(line => line(0))
        .map {
          case (vertexId, vertexLines) => {
            val outEdges = vertexLines.collect {
              case Array(_, targetId, edgeValue) =>
                new SPEdge(targetId, edgeValue.toInt)
            }
            (vertexId, new SPVertex(vertexId, Int.MaxValue, outEdges, true))
          }
        })

    // Two-field lines are the initial messages, keyed by their target vertex.
    val messages: RDD[(String, SPMessage)] =
      (lines.filter(_.length == 2)
        .map {
          case Array(vertexId, messageValue) =>
            (vertexId, new SPMessage(vertexId, messageValue.toInt))
        })

    System.err.println("Read "+vertices.count()+" vertices and "+
      messages.count()+" messages.")

    // Do the computation
    val compute = addAggregatorArg {
      (self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
        // Keep the smaller of the current distance and the combined
        // (minimum) incoming message, if there is one.
        val newValue = messageMinValue match {
          case Some(minVal) => min(self.value, minVal)
          case None => self.value
        }

        // Only an improved distance is relayed to the neighbours.
        val outbox =
          if (newValue != self.value)
            self.outEdges.map { edge =>
              // Add in Long and clamp to the Int range: a plain Int addition
              // could wrap around and yield a bogus (e.g. negative) distance
              // that min() would then prefer. Clamping high gives
              // Int.MaxValue, i.e. the "unreached" sentinel.
              val sum = newValue.toLong + edge.value
              val relayed =
                if (sum > Int.MaxValue) Int.MaxValue
                else if (sum < Int.MinValue) Int.MinValue
                else sum.toInt
              new SPMessage(edge.targetId, relayed)
            }
          else
            List()

        // The returned vertex always carries active = false.
        (new SPVertex(self.id, newValue, self.outEdges, false), outbox)
    }
    val result = Pregel.run(sc, vertices, messages)(combiner = MinCombiner, numSplits = numSplits)(compute)

    // Print the result
    System.err.println("Shortest path from "+startVertex+" to all vertices:")
    val shortest = result.map(vertex =>
      "%s\t%s\n".format(vertex.id, vertex.value match {
        case x if x == Int.MaxValue => "inf"
        case x => x
      })).collect.mkString
    println(shortest)
  }
}
/**
 * Combiner that collapses a group of [[SPMessage]]s into a single `Int`:
 * the smallest `value` seen among them.
 */
object MinCombiner extends Combiner[SPMessage, Int] {
  /** Seeds the combined value with the value of the first message. */
  def createCombiner(msg: SPMessage): Int = msg.value

  /** Folds one more message into an existing combined value, keeping the smaller. */
  def mergeMsg(combiner: Int, msg: SPMessage): Int = combiner min msg.value

  /** Merges two combined values, keeping the smaller. */
  def mergeCombiners(a: Int, b: Int): Int = a min b
}
/**
 * Vertex state for the shortest-path computation.
 *
 * @param id       identifier of this vertex
 * @param value    best distance known so far; `ShortestPath.main` initializes
 *                 it to `Int.MaxValue` and prints that value as "inf"
 * @param outEdges outgoing edges of this vertex
 * @param active   activity flag stored with the vertex (how it is interpreted
 *                 is up to the `Vertex` contract, which is not visible here)
 */
@serializable class SPVertex(
  val id: String,
  val value: Int,
  val outEdges: Seq[SPEdge],
  val active: Boolean
) extends Vertex
/**
 * Outgoing edge of an [[SPVertex]].
 *
 * @param targetId identifier of the vertex this edge leads to
 * @param value    edge weight; `ShortestPath.main` adds it to the sender's
 *                 distance when relaying a message along the edge
 */
@serializable class SPEdge(
  val targetId: String,
  val value: Int
) extends Edge
/**
 * Message exchanged between vertices during the shortest-path computation.
 *
 * @param targetId identifier of the vertex the message is addressed to
 * @param value    candidate distance carried by the message
 */
@serializable class SPMessage(
  val targetId: String,
  val value: Int
) extends Message
|