aboutsummaryrefslogtreecommitdiff
path: root/bagel/src/main/scala/bagel/ShortestPath.scala
blob: 6699f58a31c76c806c00dd231326d64ab93b9ea9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
package bagel

import spark._
import spark.SparkContext._

import scala.math.min

/**
 * Example driver that runs a shortest-path computation with `Pregel.run`.
 *
 * Input is a text file with tab-separated fields, one record per line:
 *  - three fields, `vertexId <TAB> targetId <TAB> edgeValue`, declare an
 *    outgoing edge of `vertexId`;
 *  - two fields, `vertexId <TAB> messageValue`, declare an initial message
 *    delivered to `vertexId`;
 *  - lines that are blank, whitespace-only, or whose first non-whitespace
 *    character is `#` are ignored.
 *
 * Every distinct first field becomes a vertex whose initial value is
 * `Int.MaxValue`; that value is reported as "inf" in the output.
 */
object ShortestPath {
  /**
   * Parses the graph file, runs the computation and prints one
   * `id <TAB> value` line per vertex to standard output.
   *
   * @param args `<graphFile> <startVertex> <numSplits> <host>`. Note that
   *             `startVertex` is only echoed in the output header; which
   *             vertices receive initial distances is decided by the
   *             two-field message lines of the graph file.
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 4) {
      System.err.println("Usage: ShortestPath <graphFile> <startVertex> " +
                         "<numSplits> <host>")
      System.exit(-1)
    }

    val graphFile = args(0)
    val startVertex = args(1)
    val numSplits = args(2).toInt
    val host = args(3)
    val sc = new SparkContext(host, "ShortestPath")

    // Parse the graph data from a file into two RDDs, vertices and messages.
    // Blank and whitespace-only lines are dropped together with comments:
    // an empty line would otherwise become a vertex with id "", and a line
    // made only of tabs splits to an empty array, so line(0) would throw.
    val lines =
      (sc.textFile(graphFile)
       .filter(line => !line.matches("^\\s*(#.*)?$"))
       .map(line => line.split("\t")))

    // One vertex per distinct first field. Only three-field records
    // contribute edges; two-field (message) records still create the vertex.
    val vertices: RDD[(String, SPVertex)] =
      (lines.groupBy(line => line(0))
       .map {
         case (vertexId, vertexLines) => {
           val outEdges = vertexLines.collect {
             case Array(_, targetId, edgeValue) =>
               new SPEdge(targetId, edgeValue.toInt)
           }

           (vertexId, new SPVertex(vertexId, Int.MaxValue, outEdges, true))
         }
       })

    // Initial messages come from the two-field records, keyed by target.
    val messages: RDD[(String, SPMessage)] =
      (lines.filter(_.length == 2)
       .map {
         case Array(vertexId, messageValue) =>
           (vertexId, new SPMessage(vertexId, messageValue.toInt))
       })

    System.err.println("Read "+vertices.count()+" vertices and "+
                       messages.count()+" messages.")

    // Do the computation. Messages addressed to the same vertex are combined
    // by keeping only the smallest value.
    def createCombiner(message: SPMessage): Int = message.value
    def mergeMsg(combiner: Int, message: SPMessage): Int =
      min(combiner, message.value)
    def mergeCombiners(a: Int, b: Int): Int = min(a, b)

    val result = Pregel.run(sc, vertices, messages, createCombiner, mergeMsg, mergeCombiners, numSplits) {
      (self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
        val newValue = messageMinValue match {
          case Some(minVal) => min(self.value, minVal)
          case None => self.value
        }

        // Only propagate when this vertex's value actually improved.
        val outbox =
          if (newValue != self.value)
            self.outEdges.map { edge =>
              // Add in Long and clamp at Int.MaxValue (the "inf" sentinel):
              // a wrapped, negative sum would win every min() downstream.
              val candidate =
                min(newValue.toLong + edge.value, Int.MaxValue.toLong).toInt
              new SPMessage(edge.targetId, candidate)
            }
          else
            List()

        (new SPVertex(self.id, newValue, self.outEdges, false), outbox)
    }

    // Print the result
    System.err.println("Shortest path from "+startVertex+" to all vertices:")
    val shortest = result.map { vertex =>
      // Int.MaxValue is the initial value, i.e. no message ever lowered it.
      val distance =
        if (vertex.value == Int.MaxValue) "inf" else vertex.value.toString
      "%s\t%s\n".format(vertex.id, distance)
    }.collect.mkString
    println(shortest)
  }
}

/**
 * Vertex state used by the ShortestPath example.
 *
 * @param id       vertex identifier (first field of the graph-file records)
 * @param value    smallest distance seen so far; ShortestPath.main starts
 *                 every vertex at Int.MaxValue and prints that value as "inf"
 * @param outEdges outgoing edges of this vertex
 * @param active   ShortestPath.main passes true when loading vertices and
 *                 false for the vertex returned from each compute step; how
 *                 the flag is interpreted is up to the Pregel runtime
 */
@serializable
class SPVertex(
    val id: String,
    val value: Int,
    val outEdges: Seq[SPEdge],
    val active: Boolean) extends Vertex
/**
 * Outgoing edge of an SPVertex.
 *
 * @param targetId identifier of the vertex this edge points to
 * @param value    edge value; ShortestPath.main adds it to the sender's
 *                 value when building the message for `targetId`
 */
@serializable
class SPEdge(
    val targetId: String,
    val value: Int) extends Edge
/**
 * Message exchanged between vertices in the ShortestPath example.
 *
 * @param targetId identifier of the vertex the message is addressed to
 * @param value    candidate distance for the target; messages for the same
 *                 vertex are combined in ShortestPath.main by taking the
 *                 minimum of their values
 */
@serializable
class SPMessage(
    val targetId: String,
    val value: Int) extends Message