From b7c92dded33e61976dea10beef88ab52e2009b42 Mon Sep 17 00:00:00 2001
From: Ankur Dave
Date: Thu, 9 Jan 2014 20:44:28 -0800
Subject: Add implicit algorithm methods for Graph; remove standalone PageRank

---
 .../scala/org/apache/spark/graphx/Analytics.scala  |  2 +-
 .../spark/graphx/algorithms/Algorithms.scala       | 56 ++++++++++++++++++++++
 .../graphx/algorithms/ConnectedComponents.scala    |  6 ++-
 .../apache/spark/graphx/algorithms/PageRank.scala  | 55 ++-------------------
 .../algorithms/StronglyConnectedComponents.scala   |  8 ++--
 .../apache/spark/graphx/algorithms/package.scala   |  8 ++++
 .../algorithms/ConnectedComponentsSuite.scala      |  8 ++--
 .../spark/graphx/algorithms/PageRankSuite.scala    | 27 ++++------
 .../StronglyConnectedComponentsSuite.scala         |  6 +--
 .../graphx/algorithms/TriangleCountSuite.scala     |  8 ++--
 10 files changed, 99 insertions(+), 85 deletions(-)
 create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/algorithms/Algorithms.scala
 create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/algorithms/package.scala

diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala b/graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala
index 0cafc3fdf9..def6d69190 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala
@@ -83,7 +83,7 @@ object Analytics extends Logging {
         println("GRAPHX: Number of edges " + graph.edges.count)
 
         //val pr = Analytics.pagerank(graph, numIter)
-        val pr = PageRank.runStandalone(graph, tol)
+        val pr = graph.pageRank(tol).vertices
 
         println("GRAPHX: Total rank: " + pr.map(_._2).reduce(_+_))
 
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/Algorithms.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/Algorithms.scala
new file mode 100644
index 0000000000..4af7af545c
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/Algorithms.scala
@@ -0,0 +1,56 @@
+package org.apache.spark.graphx.algorithms
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx._
+
+class Algorithms[VD: ClassTag, ED: ClassTag](self: Graph[VD, ED]) {
+  /**
+   * Run a dynamic version of PageRank returning a graph with vertex attributes containing the
+   * PageRank and edge attributes containing the normalized edge weight.
+   *
+   * @see [[org.apache.spark.graphx.algorithms.PageRank]], method `runUntilConvergence`.
+   */
+  def pageRank(tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = {
+    PageRank.runUntilConvergence(self, tol, resetProb)
+  }
+
+  /**
+   * Run PageRank for a fixed number of iterations returning a graph with vertex attributes
+   * containing the PageRank and edge attributes containing the normalized edge weight.
+   *
+   * @see [[org.apache.spark.graphx.algorithms.PageRank]], method `run`.
+   */
+  def staticPageRank(numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] = {
+    PageRank.run(self, numIter, resetProb)
+  }
+
+  /**
+   * Compute the connected component membership of each vertex and return a graph with the vertex
+   * value containing the lowest vertex id in the connected component containing that vertex.
+   *
+   * @see [[org.apache.spark.graphx.algorithms.ConnectedComponents]]
+   */
+  def connectedComponents(): Graph[VertexID, ED] = {
+    ConnectedComponents.run(self)
+  }
+
+  /**
+   * Compute the number of triangles passing through each vertex.
+   *
+   * @see [[org.apache.spark.graphx.algorithms.TriangleCount]]
+   */
+  def triangleCount(): Graph[Int, ED] = {
+    TriangleCount.run(self)
+  }
+
+  /**
+   * Compute the strongly connected component (SCC) of each vertex and return a graph with the
+   * vertex value containing the lowest vertex id in the SCC containing that vertex.
+   *
+   * @see [[org.apache.spark.graphx.algorithms.StronglyConnectedComponents]]
+   */
+  def stronglyConnectedComponents(numIter: Int): Graph[VertexID, ED] = {
+    StronglyConnectedComponents.run(self, numIter)
+  }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/ConnectedComponents.scala
index a0dd36da60..137a81f4d5 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/ConnectedComponents.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/ConnectedComponents.scala
@@ -1,11 +1,13 @@
 package org.apache.spark.graphx.algorithms
 
+import scala.reflect.ClassTag
+
 import org.apache.spark.graphx._
 
 object ConnectedComponents {
   /**
-   * Compute the connected component membership of each vertex and return an RDD with the vertex
+   * Compute the connected component membership of each vertex and return a graph with the vertex
    * value containing the lowest vertex id in the connected component containing that vertex.
    *
    * @tparam VD the vertex attribute type (discarded in the computation)
@@ -16,7 +18,7 @@ object ConnectedComponents {
    * @return a graph with vertex attributes containing the smallest vertex in each
    *         connected component
    */
-  def run[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]): Graph[VertexID, ED] = {
+  def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexID, ED] = {
     val ccGraph = graph.mapVertices { case (vid, _) => vid }
 
     def sendMessage(edge: EdgeTriplet[VertexID, ED]) = {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala
index 0292b7316d..b423104eda 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala
@@ -1,5 +1,7 @@
 package org.apache.spark.graphx.algorithms
 
+import scala.reflect.ClassTag
+
 import org.apache.spark.Logging
 import org.apache.spark.graphx._
 
@@ -42,7 +44,7 @@ object PageRank extends Logging {
    * containing the normalized weight.
    *
    */
-  def run[VD: Manifest, ED: Manifest](
+  def run[VD: ClassTag, ED: ClassTag](
      graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] = {
@@ -109,7 +111,7 @@
    * @return the graph with each vertex containing the PageRank and each edge
    *         containing the normalized weight.
    */
-  def runUntillConvergence[VD: Manifest, ED: Manifest](
+  def runUntilConvergence[VD: ClassTag, ED: ClassTag](
      graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = {
     // Initialize the pagerankGraph with each edge attribute
@@ -153,53 +155,4 @@
       .mapVertices((vid, attr) => attr._1)
   } // end of deltaPageRank
 
-  def runStandalone[VD: Manifest, ED: Manifest](
-      graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): VertexRDD[Double] = {
-
-    // Initialize the ranks
-    var ranks: VertexRDD[Double] = graph.vertices.mapValues((vid, attr) => resetProb).cache()
-
-    // Initialize the delta graph where each vertex stores its delta and each edge knows its weight
-    var deltaGraph: Graph[Double, Double] =
-      graph.outerJoinVertices(graph.outDegrees)((vid, vdata, deg) => deg.getOrElse(0))
-        .mapTriplets(e => 1.0 / e.srcAttr)
-        .mapVertices((vid, degree) => resetProb).cache()
-    var numDeltas: Long = ranks.count()
-
-    var prevDeltas: Option[VertexRDD[Double]] = None
-
-    var i = 0
-    val weight = (1.0 - resetProb)
-    while (numDeltas > 0) {
-      // Compute new deltas. Only deltas that existed in the last round (i.e., were greater than
-      // `tol`) get to send messages; those that were less than `tol` would send messages less than
-      // `tol` as well.
-      val deltas = deltaGraph
-        .mapReduceTriplets[Double](
-          et => Iterator((et.dstId, et.srcAttr * et.attr * weight)),
-          _ + _,
-          prevDeltas.map((_, EdgeDirection.Out)))
-        .filter { case (vid, delta) => delta > tol }
-        .cache()
-      prevDeltas = Some(deltas)
-      numDeltas = deltas.count()
-      logInfo("Standalone PageRank: iter %d has %d deltas".format(i, numDeltas))
-
-      // Update deltaGraph with the deltas
-      deltaGraph = deltaGraph.outerJoinVertices(deltas) { (vid, old, newOpt) =>
-        newOpt.getOrElse(old)
-      }.cache()
-
-      // Update ranks
-      ranks = ranks.leftZipJoin(deltas) { (vid, oldRank, deltaOpt) =>
-        oldRank + deltaOpt.getOrElse(0.0)
-      }
-      ranks.foreach(x => {}) // force the iteration for ease of debugging
-
-      i += 1
-    }
-
-    ranks
-  }
-
 }
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponents.scala
index f64fc3ef0f..49ec91aedd 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponents.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponents.scala
@@ -1,12 +1,14 @@
 package org.apache.spark.graphx.algorithms
 
+import scala.reflect.ClassTag
+
 import org.apache.spark.graphx._
 
 object StronglyConnectedComponents {
 
   /**
-   * Compute the strongly connected component (SCC) of each vertex and return an RDD with the vertex
-   * value containing the lowest vertex id in the SCC containing that vertex.
+   * Compute the strongly connected component (SCC) of each vertex and return a graph with the
+   * vertex value containing the lowest vertex id in the SCC containing that vertex.
    *
    * @tparam VD the vertex attribute type (discarded in the computation)
    * @tparam ED the edge attribute type (preserved in the computation)
    * @param graph the graph for which to compute the SCC
    *
@@ -15,7 +17,7 @@ object StronglyConnectedComponents {
    * @return a graph with vertex attributes containing the smallest vertex id in each SCC
    */
-  def run[VD: Manifest, ED: Manifest](graph: Graph[VD, ED], numIter: Int): Graph[VertexID, ED] = {
+  def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int): Graph[VertexID, ED] = {
 
     // the graph we update with final SCC ids, and the graph we return at the end
     var sccGraph = graph.mapVertices { case (vid, _) => vid }
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/package.scala
new file mode 100644
index 0000000000..fbabf1257c
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/package.scala
@@ -0,0 +1,8 @@
+package org.apache.spark.graphx
+
+import scala.reflect.ClassTag
+
+package object algorithms {
+  implicit def graphToAlgorithms[VD: ClassTag, ED: ClassTag](
+      graph: Graph[VD, ED]): Algorithms[VD, ED] = new Algorithms(graph)
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala
index 5e2ecfcde9..209191ef07 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala
@@ -14,7 +14,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
   test("Grid Connected Components") {
     withSpark { sc =>
       val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).cache()
-      val ccGraph = ConnectedComponents.run(gridGraph).cache()
+      val ccGraph = gridGraph.connectedComponents().cache()
       val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
       assert(maxCCid === 0)
     }
@@ -24,7 +24,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
   test("Reverse Grid Connected Components") {
     withSpark { sc =>
       val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse.cache()
-      val ccGraph = ConnectedComponents.run(gridGraph).cache()
+      val ccGraph = gridGraph.connectedComponents().cache()
       val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
       assert(maxCCid === 0)
     }
@@ -37,7 +37,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
       val chain2 = (10 until 20).map(x => (x, x+1) )
       val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
       val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
-      val ccGraph = ConnectedComponents.run(twoChains).cache()
+      val ccGraph = twoChains.connectedComponents().cache()
       val vertices = ccGraph.vertices.collect()
       for ( (id, cc) <- vertices ) {
         if(id < 10) { assert(cc === 0) }
@@ -60,7 +60,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
       val chain2 = (10 until 20).map(x => (x, x+1) )
       val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
       val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse.cache()
-      val ccGraph = ConnectedComponents.run(twoChains).cache()
+      val ccGraph = twoChains.connectedComponents().cache()
       val vertices = ccGraph.vertices.collect
       for ( (id, cc) <- vertices ) {
         if (id < 10) {
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala
index e365b1e230..cd857bd3a1 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala
@@ -2,13 +2,12 @@ package org.apache.spark.graphx.algorithms
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.graphx._
 import org.apache.spark.SparkContext
 import org.apache.spark.SparkContext._
-import org.apache.spark.rdd._
-
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.algorithms._
 import org.apache.spark.graphx.util.GraphGenerators
-
+import org.apache.spark.rdd._
 
 object GridPageRank {
   def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double) = {
@@ -58,8 +57,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
       val resetProb = 0.15
       val errorTol = 1.0e-5
 
-      val staticRanks1 = PageRank.run(starGraph, numIter = 1, resetProb).vertices.cache()
-      val staticRanks2 = PageRank.run(starGraph, numIter = 2, resetProb).vertices.cache()
+      val staticRanks1 = starGraph.staticPageRank(numIter = 1, resetProb).vertices.cache()
+      val staticRanks2 = starGraph.staticPageRank(numIter = 2, resetProb).vertices.cache()
 
       // Static PageRank should only take 2 iterations to converge
       val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) =>
@@ -74,10 +73,8 @@
       }
       assert(staticErrors.sum === 0)
 
-      val dynamicRanks = PageRank.runUntillConvergence(starGraph, 0, resetProb).vertices.cache()
-      val standaloneRanks = PageRank.runStandalone(starGraph, 0, resetProb).cache()
+      val dynamicRanks = starGraph.pageRank(0, resetProb).vertices.cache()
       assert(compareRanks(staticRanks2, dynamicRanks) < errorTol)
-      assert(compareRanks(staticRanks2, standaloneRanks) < errorTol)
     }
   } // end of test Star PageRank
 
@@ -93,14 +90,12 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
       val errorTol = 1.0e-5
 
       val gridGraph = GraphGenerators.gridGraph(sc, rows, cols).cache()
-      val staticRanks = PageRank.run(gridGraph, numIter, resetProb).vertices.cache()
-      val dynamicRanks = PageRank.runUntillConvergence(gridGraph, tol, resetProb).vertices.cache()
-      val standaloneRanks = PageRank.runStandalone(gridGraph, tol, resetProb).cache()
+      val staticRanks = gridGraph.staticPageRank(numIter, resetProb).vertices.cache()
+      val dynamicRanks = gridGraph.pageRank(tol, resetProb).vertices.cache()
       val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb)))
 
       assert(compareRanks(staticRanks, referenceRanks) < errorTol)
       assert(compareRanks(dynamicRanks, referenceRanks) < errorTol)
-      assert(compareRanks(standaloneRanks, referenceRanks) < errorTol)
     }
   } // end of Grid PageRank
 
@@ -115,12 +110,10 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
       val numIter = 10
       val errorTol = 1.0e-5
 
-      val staticRanks = PageRank.run(chain, numIter, resetProb).vertices.cache()
-      val dynamicRanks = PageRank.runUntillConvergence(chain, tol, resetProb).vertices.cache()
-      val standaloneRanks = PageRank.runStandalone(chain, tol, resetProb).cache()
+      val staticRanks = chain.staticPageRank(numIter, resetProb).vertices.cache()
+      val dynamicRanks = chain.pageRank(tol, resetProb).vertices.cache()
 
       assert(compareRanks(staticRanks, dynamicRanks) < errorTol)
-      assert(compareRanks(dynamicRanks, standaloneRanks) < errorTol)
     }
   }
 }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponentsSuite.scala
index 696b80944e..fee7d20161 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponentsSuite.scala
@@ -16,7 +16,7 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
       val vertices = sc.parallelize((1L to 5L).map(x => (x, -1)))
       val edges = sc.parallelize(Seq.empty[Edge[Int]])
       val graph = Graph(vertices, edges)
-      val sccGraph = StronglyConnectedComponents.run(graph, 5)
+      val sccGraph = graph.stronglyConnectedComponents(5)
       for ((id, scc) <- sccGraph.vertices.collect) {
         assert(id == scc)
       }
@@ -27,7 +27,7 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
     withSpark { sc =>
       val rawEdges = sc.parallelize((0L to 6L).map(x => (x, (x + 1) % 7)))
       val graph = Graph.fromEdgeTuples(rawEdges, -1)
-      val sccGraph = StronglyConnectedComponents.run(graph, 20)
+      val sccGraph = graph.stronglyConnectedComponents(20)
       for ((id, scc) <- sccGraph.vertices.collect) {
         assert(0L == scc)
       }
@@ -42,7 +42,7 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
         Array(6L -> 0L, 5L -> 7L)
       val rawEdges = sc.parallelize(edges)
       val graph = Graph.fromEdgeTuples(rawEdges, -1)
-      val sccGraph = StronglyConnectedComponents.run(graph, 20)
+      val sccGraph = graph.stronglyConnectedComponents(20)
       for ((id, scc) <- sccGraph.vertices.collect) {
         if (id < 3)
           assert(0L == scc)
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/TriangleCountSuite.scala
index 0e59912754..b85b289da6 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/TriangleCountSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/TriangleCountSuite.scala
@@ -15,7 +15,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext {
     withSpark { sc =>
       val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2)
       val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
-      val triangleCount = TriangleCount.run(graph)
+      val triangleCount = graph.triangleCount()
       val verts = triangleCount.vertices
       verts.collect.foreach { case (vid, count) => assert(count === 1) }
     }
@@ -27,7 +27,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext {
         Array(0L -> -1L, -1L -> -2L, -2L -> 0L)
       val rawEdges = sc.parallelize(triangles, 2)
       val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
-      val triangleCount = TriangleCount.run(graph)
+      val triangleCount = graph.triangleCount()
       val verts = triangleCount.vertices
       verts.collect().foreach { case (vid, count) =>
         if (vid == 0) {
@@ -47,7 +47,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext {
       val revTriangles = triangles.map { case (a,b) => (b,a) }
       val rawEdges = sc.parallelize(triangles ++ revTriangles, 2)
       val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
-      val triangleCount = TriangleCount.run(graph)
+      val triangleCount = graph.triangleCount()
       val verts = triangleCount.vertices
       verts.collect().foreach { case (vid, count) =>
         if (vid == 0) {
@@ -64,7 +64,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext {
       val rawEdges = sc.parallelize(Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
         Array(0L -> 1L, 1L -> 2L, 2L -> 0L), 2)
       val graph = Graph.fromEdgeTuples(rawEdges, true, uniqueEdges = Some(RandomVertexCut)).cache()
-      val triangleCount = TriangleCount.run(graph)
+      val triangleCount = graph.triangleCount()
       val verts = triangleCount.vertices
       verts.collect.foreach { case (vid, count) => assert(count === 1) }
     }
-- 
cgit v1.2.3
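
Usage sketch: with this patch applied, the algorithms are reached through the implicit
conversion in the `algorithms` package object rather than through the standalone objects.
A minimal sketch, assuming an existing SparkContext named `sc` and the GraphX snapshot above:

    import org.apache.spark.graphx._
    import org.apache.spark.graphx.algorithms._ // brings the implicit graphToAlgorithms into scope

    // Build a small graph from edge tuples, as the test suites above do.
    // The `true` edge attribute is a placeholder; these algorithms ignore it.
    val rawEdges = sc.parallelize(Array(0L -> 1L, 1L -> 2L, 2L -> 0L))
    val graph = Graph.fromEdgeTuples(rawEdges, true).cache()

    // Each method is thin sugar over the corresponding algorithm object, e.g.
    // graph.pageRank(tol) delegates to PageRank.runUntilConvergence(graph, tol).
    val dynamicRanks = graph.pageRank(tol = 0.0001).vertices       // run until convergence
    val staticRanks = graph.staticPageRank(numIter = 10).vertices  // fixed iteration count
    val components = graph.connectedComponents().vertices          // lowest vertex id per component
    val triangles = graph.triangleCount().vertices                 // triangles through each vertex

Because the conversion lives in the `algorithms` package object, callers opt in with a single
import while Graph itself stays free of algorithm-specific dependencies.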