From b7c92dded33e61976dea10beef88ab52e2009b42 Mon Sep 17 00:00:00 2001
From: Ankur Dave
Date: Thu, 9 Jan 2014 20:44:28 -0800
Subject: Add implicit algorithm methods for Graph; remove standalone PageRank

---
 .../algorithms/ConnectedComponentsSuite.scala      |  8 +++----
 .../spark/graphx/algorithms/PageRankSuite.scala    | 27 ++++++++--------------
 .../StronglyConnectedComponentsSuite.scala         |  6 ++---
 .../graphx/algorithms/TriangleCountSuite.scala     |  8 +++----
 4 files changed, 21 insertions(+), 28 deletions(-)

(limited to 'graphx/src/test/scala/org/apache')

diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala
index 5e2ecfcde9..209191ef07 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala
@@ -14,7 +14,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
   test("Grid Connected Components") {
     withSpark { sc =>
       val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).cache()
-      val ccGraph = ConnectedComponents.run(gridGraph).cache()
+      val ccGraph = gridGraph.connectedComponents().cache()
       val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
       assert(maxCCid === 0)
     }
@@ -24,7 +24,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
   test("Reverse Grid Connected Components") {
     withSpark { sc =>
       val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse.cache()
-      val ccGraph = ConnectedComponents.run(gridGraph).cache()
+      val ccGraph = gridGraph.connectedComponents().cache()
       val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
       assert(maxCCid === 0)
     }
@@ -37,7 +37,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
       val chain2 = (10 until 20).map(x => (x, x+1) )
       val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
       val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
-      val ccGraph = ConnectedComponents.run(twoChains).cache()
+      val ccGraph = twoChains.connectedComponents().cache()
       val vertices = ccGraph.vertices.collect()
       for ( (id, cc) <- vertices ) {
         if(id < 10) { assert(cc === 0) }
@@ -60,7 +60,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
       val chain2 = (10 until 20).map(x => (x, x+1) )
       val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
       val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse.cache()
-      val ccGraph = ConnectedComponents.run(twoChains).cache()
+      val ccGraph = twoChains.connectedComponents().cache()
       val vertices = ccGraph.vertices.collect
       for ( (id, cc) <- vertices ) {
         if (id < 10) {
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala
index e365b1e230..cd857bd3a1 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala
@@ -2,13 +2,12 @@ package org.apache.spark.graphx.algorithms
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.graphx._
 import org.apache.spark.SparkContext
 import org.apache.spark.SparkContext._
-import org.apache.spark.rdd._
-
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.algorithms._
 import org.apache.spark.graphx.util.GraphGenerators
-
+import org.apache.spark.rdd._
 
 object GridPageRank {
   def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double) = {
@@ -58,8 +57,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
       val resetProb = 0.15
      val errorTol = 1.0e-5
 
-      val staticRanks1 = PageRank.run(starGraph, numIter = 1, resetProb).vertices.cache()
-      val staticRanks2 = PageRank.run(starGraph, numIter = 2, resetProb).vertices.cache()
+      val staticRanks1 = starGraph.staticPageRank(numIter = 1, resetProb).vertices.cache()
+      val staticRanks2 = starGraph.staticPageRank(numIter = 2, resetProb).vertices.cache()
 
       // Static PageRank should only take 2 iterations to converge
       val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) =>
@@ -74,10 +73,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
       }
       assert(staticErrors.sum === 0)
 
-      val dynamicRanks = PageRank.runUntillConvergence(starGraph, 0, resetProb).vertices.cache()
-      val standaloneRanks = PageRank.runStandalone(starGraph, 0, resetProb).cache()
+      val dynamicRanks = starGraph.pageRank(0, resetProb).vertices.cache()
       assert(compareRanks(staticRanks2, dynamicRanks) < errorTol)
-      assert(compareRanks(staticRanks2, standaloneRanks) < errorTol)
     }
   } // end of test Star PageRank
 
@@ -93,14 +90,12 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
       val errorTol = 1.0e-5
 
       val gridGraph = GraphGenerators.gridGraph(sc, rows, cols).cache()
 
-      val staticRanks = PageRank.run(gridGraph, numIter, resetProb).vertices.cache()
-      val dynamicRanks = PageRank.runUntillConvergence(gridGraph, tol, resetProb).vertices.cache()
-      val standaloneRanks = PageRank.runStandalone(gridGraph, tol, resetProb).cache()
+      val staticRanks = gridGraph.staticPageRank(numIter, resetProb).vertices.cache()
+      val dynamicRanks = gridGraph.pageRank(tol, resetProb).vertices.cache()
       val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb)))
 
       assert(compareRanks(staticRanks, referenceRanks) < errorTol)
       assert(compareRanks(dynamicRanks, referenceRanks) < errorTol)
-      assert(compareRanks(standaloneRanks, referenceRanks) < errorTol)
     }
   } // end of Grid PageRank
@@ -115,12 +110,10 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
       val numIter = 10
       val errorTol = 1.0e-5
 
-      val staticRanks = PageRank.run(chain, numIter, resetProb).vertices.cache()
-      val dynamicRanks = PageRank.runUntillConvergence(chain, tol, resetProb).vertices.cache()
-      val standaloneRanks = PageRank.runStandalone(chain, tol, resetProb).cache()
+      val staticRanks = chain.staticPageRank(numIter, resetProb).vertices.cache()
+      val dynamicRanks = chain.pageRank(tol, resetProb).vertices.cache()
 
       assert(compareRanks(staticRanks, dynamicRanks) < errorTol)
-      assert(compareRanks(dynamicRanks, standaloneRanks) < errorTol)
     }
   }
 }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponentsSuite.scala
index 696b80944e..fee7d20161 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponentsSuite.scala
@@ -16,7 +16,7 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
       val vertices = sc.parallelize((1L to 5L).map(x => (x, -1)))
       val edges = sc.parallelize(Seq.empty[Edge[Int]])
       val graph = Graph(vertices, edges)
-      val sccGraph = StronglyConnectedComponents.run(graph, 5)
+      val sccGraph = graph.stronglyConnectedComponents(5)
       for ((id, scc) <- sccGraph.vertices.collect) {
         assert(id == scc)
       }
@@ -27,7 +27,7 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
     withSpark { sc =>
       val rawEdges = sc.parallelize((0L to 6L).map(x => (x, (x + 1) % 7)))
       val graph = Graph.fromEdgeTuples(rawEdges, -1)
-      val sccGraph = StronglyConnectedComponents.run(graph, 20)
+      val sccGraph = graph.stronglyConnectedComponents(20)
       for ((id, scc) <- sccGraph.vertices.collect) {
         assert(0L == scc)
       }
@@ -42,7 +42,7 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
         Array(6L -> 0L, 5L -> 7L)
       val rawEdges = sc.parallelize(edges)
       val graph = Graph.fromEdgeTuples(rawEdges, -1)
-      val sccGraph = StronglyConnectedComponents.run(graph, 20)
+      val sccGraph = graph.stronglyConnectedComponents(20)
       for ((id, scc) <- sccGraph.vertices.collect) {
         if (id < 3)
           assert(0L == scc)
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/TriangleCountSuite.scala
index 0e59912754..b85b289da6 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/TriangleCountSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/TriangleCountSuite.scala
@@ -15,7 +15,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext {
     withSpark { sc =>
       val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2)
       val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
-      val triangleCount = TriangleCount.run(graph)
+      val triangleCount = graph.triangleCount()
       val verts = triangleCount.vertices
       verts.collect.foreach { case (vid, count) => assert(count === 1) }
     }
@@ -27,7 +27,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext {
         Array(0L -> -1L, -1L -> -2L, -2L -> 0L)
       val rawEdges = sc.parallelize(triangles, 2)
       val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
-      val triangleCount = TriangleCount.run(graph)
+      val triangleCount = graph.triangleCount()
       val verts = triangleCount.vertices
       verts.collect().foreach { case (vid, count) =>
         if (vid == 0) {
@@ -47,7 +47,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext {
       val revTriangles = triangles.map { case (a,b) => (b,a) }
       val rawEdges = sc.parallelize(triangles ++ revTriangles, 2)
       val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
-      val triangleCount = TriangleCount.run(graph)
+      val triangleCount = graph.triangleCount()
       val verts = triangleCount.vertices
       verts.collect().foreach { case (vid, count) =>
         if (vid == 0) {
@@ -64,7 +64,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext {
       val rawEdges = sc.parallelize(Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
         Array(0L -> 1L, 1L -> 2L, 2L -> 0L), 2)
       val graph = Graph.fromEdgeTuples(rawEdges, true, uniqueEdges = Some(RandomVertexCut)).cache()
-      val triangleCount = TriangleCount.run(graph)
+      val triangleCount = graph.triangleCount()
       val verts = triangleCount.vertices
       verts.collect.foreach { case (vid, count) => assert(count === 1) }
     }
-- 
cgit v1.2.3
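Note: this view is limited to 'graphx/src/test/scala/org/apache', so the commit's main-source half (the implicit algorithm methods themselves) does not appear above. The sketch below illustrates the enrichment pattern the updated suites depend on: a wrapper class whose methods delegate to the standalone algorithm objects, plus an implicit conversion in the package object. It is a minimal sketch inferred from the call sites in the tests, not the commit's actual implementation; the names Algorithms and graphToAlgorithms, the file paths in the comments, the resetProb defaults, and the return types are all assumptions.

// Hypothetical file: org/apache/spark/graphx/algorithms/Algorithms.scala
package org.apache.spark.graphx.algorithms

import scala.reflect.ClassTag

import org.apache.spark.graphx._

// Wraps a Graph and exposes each algorithm as a method. Every method simply
// delegates to the standalone entry points that the old tests called
// directly (PageRank.run, ConnectedComponents.run, and so on).
class Algorithms[VD: ClassTag, ED: ClassTag](self: Graph[VD, ED]) {

  // PageRank for a fixed number of iterations (the 0.15 default is an assumption).
  def staticPageRank(numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] =
    PageRank.run(self, numIter, resetProb)

  // PageRank until the per-vertex change drops below tol; the "Untill"
  // spelling matches the delegate method visible in the diff above.
  def pageRank(tol: Double, resetProb: Double = 0.15): Graph[Double, Double] =
    PageRank.runUntillConvergence(self, tol, resetProb)

  // Label each vertex with the smallest vertex id in its connected component.
  def connectedComponents(): Graph[Long, ED] =
    ConnectedComponents.run(self)

  // Same labeling for strongly connected components, capped at numIter iterations.
  def stronglyConnectedComponents(numIter: Int): Graph[Long, ED] =
    StronglyConnectedComponents.run(self, numIter)

  // Number of triangles passing through each vertex.
  def triangleCount(): Graph[Int, ED] =
    TriangleCount.run(self)
}

// Hypothetical file: org/apache/spark/graphx/algorithms/package.scala
// Placing the conversion in the package object means code inside
// org.apache.spark.graphx.algorithms sees it automatically, and outside
// code gets it via `import org.apache.spark.graphx.algorithms._`.
package org.apache.spark.graphx

import scala.reflect.ClassTag

package object algorithms {
  implicit def graphToAlgorithms[VD: ClassTag, ED: ClassTag](
      graph: Graph[VD, ED]): algorithms.Algorithms[VD, ED] =
    new algorithms.Algorithms(graph)
}

With such a conversion in scope, a call like gridGraph.connectedComponents().cache() in ConnectedComponentsSuite elaborates to graphToAlgorithms(gridGraph).connectedComponents().cache(). Under this assumed layout, the suites already live in the algorithms package and would see the package-object implicit without any import, which is consistent with only PageRankSuite's import block changing in the diff.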