author    Ankur Dave <ankurdave@gmail.com>  2014-01-09 20:44:28 -0800
committer Ankur Dave <ankurdave@gmail.com>  2014-01-09 20:44:28 -0800
commit b7c92dded33e61976dea10beef88ab52e2009b42
tree   04e612d4903d8b04a08ea65bbc98c17b5195daa4
parent 731f56f309914e3fc7c22c8ef1c8cb9dd40d42c1
Add implicit algorithm methods for Graph; remove standalone PageRank
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala                                    |  2
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/algorithms/Algorithms.scala                        | 56
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/algorithms/ConnectedComponents.scala               |  6
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala                          | 55
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponents.scala       |  8
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/algorithms/package.scala                           |  8
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala          |  8
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala                     | 27
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponentsSuite.scala  |  6
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/algorithms/TriangleCountSuite.scala                |  8
10 files changed, 99 insertions(+), 85 deletions(-)
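In short, this commit replaces standalone algorithm entry points such as PageRank.runStandalone with methods attached to Graph itself through an implicit conversion (see the new Algorithms.scala and package.scala below). A minimal sketch of the resulting call style; the SparkContext setup and edge data here are illustrative, not part of the patch:

    import org.apache.spark.SparkContext
    import org.apache.spark.graphx._
    import org.apache.spark.graphx.algorithms._  // brings graphToAlgorithms into scope

    val sc = new SparkContext("local", "algorithms-example")
    val rawEdges = sc.parallelize(Array(0L -> 1L, 1L -> 2L, 2L -> 0L))
    val graph = Graph.fromEdgeTuples(rawEdges, defaultValue = 1).cache()

    // Algorithm methods now hang off Graph via the implicit conversion:
    val ranks = graph.pageRank(tol = 0.0001).vertices  // replaces PageRank.runStandalone
    val ccs   = graph.connectedComponents().vertices
    val tris  = graph.triangleCount().vertices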
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala b/graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala
index 0cafc3fdf9..def6d69190 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala
@@ -83,7 +83,7 @@ object Analytics extends Logging {
println("GRAPHX: Number of edges " + graph.edges.count)
//val pr = Analytics.pagerank(graph, numIter)
- val pr = PageRank.runStandalone(graph, tol)
+ val pr = graph.pageRank(tol).vertices
println("GRAPHX: Total rank: " + pr.map(_._2).reduce(_+_))
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/Algorithms.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/Algorithms.scala
new file mode 100644
index 0000000000..4af7af545c
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/Algorithms.scala
@@ -0,0 +1,56 @@
+package org.apache.spark.graphx.algorithms
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx._
+
+class Algorithms[VD: ClassTag, ED: ClassTag](self: Graph[VD, ED]) {
+ /**
+ * Run a dynamic version of PageRank returning a graph with vertex attributes containing the
+ * PageRank and edge attributes containing the normalized edge weight.
+ *
+ * @see [[org.apache.spark.graphx.algorithms.PageRank]], method `runUntilConvergence`.
+ */
+ def pageRank(tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = {
+ PageRank.runUntilConvergence(self, tol, resetProb)
+ }
+
+ /**
+ * Run PageRank for a fixed number of iterations returning a graph with vertex attributes
+ * containing the PageRank and edge attributes containing the normalized edge weight.
+ *
+ * @see [[org.apache.spark.graphx.algorithms.PageRank]], method `run`.
+ */
+ def staticPageRank(numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] = {
+ PageRank.run(self, numIter, resetProb)
+ }
+
+ /**
+ * Compute the connected component membership of each vertex and return a graph with the vertex
+ * value containing the lowest vertex id in the connected component containing that vertex.
+ *
+ * @see [[org.apache.spark.graphx.algorithms.ConnectedComponents]]
+ */
+ def connectedComponents(): Graph[VertexID, ED] = {
+ ConnectedComponents.run(self)
+ }
+
+ /**
+ * Compute the number of triangles passing through each vertex.
+ *
+ * @see [[org.apache.spark.graphx.algorithms.TriangleCount]]
+ */
+ def triangleCount(): Graph[Int, ED] = {
+ TriangleCount.run(self)
+ }
+
+ /**
+ * Compute the strongly connected component (SCC) of each vertex and return a graph with the
+ * vertex value containing the lowest vertex id in the SCC containing that vertex.
+ *
+ * @see [[org.apache.spark.graphx.algorithms.StronglyConnectedComponents]]
+ */
+ def stronglyConnectedComponents(numIter: Int): Graph[VertexID, ED] = {
+ StronglyConnectedComponents.run(self, numIter)
+ }
+}
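Each wrapper above simply forwards to the corresponding algorithm object, so existing call sites translate mechanically. A before/after sketch for an arbitrary cached graph (the variable names are illustrative):

    // Before this commit:
    val staticRanks = PageRank.run(graph, numIter = 10).vertices
    val sccGraph    = StronglyConnectedComponents.run(graph, numIter = 20)

    // After, with org.apache.spark.graphx.algorithms._ imported:
    val staticRanks2 = graph.staticPageRank(numIter = 10).vertices
    val sccGraph2    = graph.stronglyConnectedComponents(numIter = 20)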
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/ConnectedComponents.scala
index a0dd36da60..137a81f4d5 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/ConnectedComponents.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/ConnectedComponents.scala
@@ -1,11 +1,13 @@
package org.apache.spark.graphx.algorithms
+import scala.reflect.ClassTag
+
import org.apache.spark.graphx._
object ConnectedComponents {
/**
- * Compute the connected component membership of each vertex and return an RDD with the vertex
+ * Compute the connected component membership of each vertex and return a graph with the vertex
* value containing the lowest vertex id in the connected component containing that vertex.
*
* @tparam VD the vertex attribute type (discarded in the computation)
@@ -16,7 +18,7 @@ object ConnectedComponents {
* @return a graph with vertex attributes containing the smallest vertex id in each
* connected component
*/
- def run[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]): Graph[VertexID, ED] = {
+ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexID, ED] = {
val ccGraph = graph.mapVertices { case (vid, _) => vid }
def sendMessage(edge: EdgeTriplet[VertexID, ED]) = {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala
index 0292b7316d..b423104eda 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala
@@ -1,5 +1,7 @@
package org.apache.spark.graphx.algorithms
+import scala.reflect.ClassTag
+
import org.apache.spark.Logging
import org.apache.spark.graphx._
@@ -42,7 +44,7 @@ object PageRank extends Logging {
* containing the normalized weight.
*
*/
- def run[VD: Manifest, ED: Manifest](
+ def run[VD: ClassTag, ED: ClassTag](
graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] =
{
@@ -109,7 +111,7 @@ object PageRank extends Logging {
* @return the graph with each vertex containing the PageRank and each edge
* containing the normalized weight.
*/
- def runUntillConvergence[VD: Manifest, ED: Manifest](
+ def runUntilConvergence[VD: ClassTag, ED: ClassTag](
graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double] =
{
// Initialize the pagerankGraph with each edge attribute
@@ -153,53 +155,4 @@ object PageRank extends Logging {
.mapVertices((vid, attr) => attr._1)
} // end of deltaPageRank
- def runStandalone[VD: Manifest, ED: Manifest](
- graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): VertexRDD[Double] = {
-
- // Initialize the ranks
- var ranks: VertexRDD[Double] = graph.vertices.mapValues((vid, attr) => resetProb).cache()
-
- // Initialize the delta graph where each vertex stores its delta and each edge knows its weight
- var deltaGraph: Graph[Double, Double] =
- graph.outerJoinVertices(graph.outDegrees)((vid, vdata, deg) => deg.getOrElse(0))
- .mapTriplets(e => 1.0 / e.srcAttr)
- .mapVertices((vid, degree) => resetProb).cache()
- var numDeltas: Long = ranks.count()
-
- var prevDeltas: Option[VertexRDD[Double]] = None
-
- var i = 0
- val weight = (1.0 - resetProb)
- while (numDeltas > 0) {
- // Compute new deltas. Only deltas that existed in the last round (i.e., were greater than
- // `tol`) get to send messages; those that were less than `tol` would send messages less than
- // `tol` as well.
- val deltas = deltaGraph
- .mapReduceTriplets[Double](
- et => Iterator((et.dstId, et.srcAttr * et.attr * weight)),
- _ + _,
- prevDeltas.map((_, EdgeDirection.Out)))
- .filter { case (vid, delta) => delta > tol }
- .cache()
- prevDeltas = Some(deltas)
- numDeltas = deltas.count()
- logInfo("Standalone PageRank: iter %d has %d deltas".format(i, numDeltas))
-
- // Update deltaGraph with the deltas
- deltaGraph = deltaGraph.outerJoinVertices(deltas) { (vid, old, newOpt) =>
- newOpt.getOrElse(old)
- }.cache()
-
- // Update ranks
- ranks = ranks.leftZipJoin(deltas) { (vid, oldRank, deltaOpt) =>
- oldRank + deltaOpt.getOrElse(0.0)
- }
- ranks.foreach(x => {}) // force the iteration for ease of debugging
-
- i += 1
- }
-
- ranks
- }
-
}
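The removed runStandalone drove its own delta loop with mapReduceTriplets and returned a VertexRDD[Double] directly; its convergence behavior is subsumed by runUntilConvergence, exposed as graph.pageRank. Callers migrate as in the Analytics.scala hunk above, sketched here with an illustrative tolerance:

    // Before: val ranks: VertexRDD[Double] = PageRank.runStandalone(graph, tol = 0.001)
    // After: take the vertices of the graph returned by the convergent variant
    val ranks: VertexRDD[Double] = graph.pageRank(tol = 0.001).vertices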
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponents.scala
index f64fc3ef0f..49ec91aedd 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponents.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponents.scala
@@ -1,12 +1,14 @@
package org.apache.spark.graphx.algorithms
+import scala.reflect.ClassTag
+
import org.apache.spark.graphx._
object StronglyConnectedComponents {
/**
- * Compute the strongly connected component (SCC) of each vertex and return an RDD with the vertex
- * value containing the lowest vertex id in the SCC containing that vertex.
+ * Compute the strongly connected component (SCC) of each vertex and return a graph with the
+ * vertex value containing the lowest vertex id in the SCC containing that vertex.
*
* @tparam VD the vertex attribute type (discarded in the computation)
* @tparam ED the edge attribute type (preserved in the computation)
@@ -15,7 +17,7 @@ object StronglyConnectedComponents {
*
* @return a graph with vertex attributes containing the smallest vertex id in each SCC
*/
- def run[VD: Manifest, ED: Manifest](graph: Graph[VD, ED], numIter: Int): Graph[VertexID, ED] = {
+ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int): Graph[VertexID, ED] = {
// the graph we update with final SCC ids, and the graph we return at the end
var sccGraph = graph.mapVertices { case (vid, _) => vid }
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/package.scala
new file mode 100644
index 0000000000..fbabf1257c
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/package.scala
@@ -0,0 +1,8 @@
+package org.apache.spark.graphx
+
+import scala.reflect.ClassTag
+
+package object algorithms {
+ implicit def graphToAlgorithms[VD: ClassTag, ED: ClassTag](
+ graph: Graph[VD, ED]): Algorithms[VD, ED] = new Algorithms(graph)
+}
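This package object is what makes the new methods appear on Graph: once org.apache.spark.graphx.algorithms._ is imported, the compiler applies the conversion wherever a Graph is used where an Algorithms method is expected. A sketch of the desugaring, for illustration:

    import org.apache.spark.graphx._
    import org.apache.spark.graphx.algorithms._

    // With the implicit in scope, these two calls are equivalent:
    val counts1 = graph.triangleCount()
    val counts2 = graphToAlgorithms(graph).triangleCount()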
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala
index 5e2ecfcde9..209191ef07 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala
@@ -14,7 +14,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
test("Grid Connected Components") {
withSpark { sc =>
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).cache()
- val ccGraph = ConnectedComponents.run(gridGraph).cache()
+ val ccGraph = gridGraph.connectedComponents().cache()
val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
assert(maxCCid === 0)
}
@@ -24,7 +24,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
test("Reverse Grid Connected Components") {
withSpark { sc =>
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse.cache()
- val ccGraph = ConnectedComponents.run(gridGraph).cache()
+ val ccGraph = gridGraph.connectedComponents().cache()
val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
assert(maxCCid === 0)
}
@@ -37,7 +37,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
val chain2 = (10 until 20).map(x => (x, x+1) )
val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
- val ccGraph = ConnectedComponents.run(twoChains).cache()
+ val ccGraph = twoChains.connectedComponents().cache()
val vertices = ccGraph.vertices.collect()
for ( (id, cc) <- vertices ) {
if(id < 10) { assert(cc === 0) }
@@ -60,7 +60,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
val chain2 = (10 until 20).map(x => (x, x+1) )
val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse.cache()
- val ccGraph = ConnectedComponents.run(twoChains).cache()
+ val ccGraph = twoChains.connectedComponents().cache()
val vertices = ccGraph.vertices.collect
for ( (id, cc) <- vertices ) {
if (id < 10) {
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala
index e365b1e230..cd857bd3a1 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala
@@ -2,13 +2,12 @@ package org.apache.spark.graphx.algorithms
import org.scalatest.FunSuite
-import org.apache.spark.graphx._
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
-import org.apache.spark.rdd._
-
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.algorithms._
import org.apache.spark.graphx.util.GraphGenerators
-
+import org.apache.spark.rdd._
object GridPageRank {
def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double) = {
@@ -58,8 +57,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
val resetProb = 0.15
val errorTol = 1.0e-5
- val staticRanks1 = PageRank.run(starGraph, numIter = 1, resetProb).vertices.cache()
- val staticRanks2 = PageRank.run(starGraph, numIter = 2, resetProb).vertices.cache()
+ val staticRanks1 = starGraph.staticPageRank(numIter = 1, resetProb).vertices.cache()
+ val staticRanks2 = starGraph.staticPageRank(numIter = 2, resetProb).vertices.cache()
// Static PageRank should only take 2 iterations to converge
val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) =>
@@ -74,10 +73,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
}
assert(staticErrors.sum === 0)
- val dynamicRanks = PageRank.runUntillConvergence(starGraph, 0, resetProb).vertices.cache()
- val standaloneRanks = PageRank.runStandalone(starGraph, 0, resetProb).cache()
+ val dynamicRanks = starGraph.pageRank(0, resetProb).vertices.cache()
assert(compareRanks(staticRanks2, dynamicRanks) < errorTol)
- assert(compareRanks(staticRanks2, standaloneRanks) < errorTol)
}
} // end of test Star PageRank
@@ -93,14 +90,12 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
val errorTol = 1.0e-5
val gridGraph = GraphGenerators.gridGraph(sc, rows, cols).cache()
- val staticRanks = PageRank.run(gridGraph, numIter, resetProb).vertices.cache()
- val dynamicRanks = PageRank.runUntillConvergence(gridGraph, tol, resetProb).vertices.cache()
- val standaloneRanks = PageRank.runStandalone(gridGraph, tol, resetProb).cache()
+ val staticRanks = gridGraph.staticPageRank(numIter, resetProb).vertices.cache()
+ val dynamicRanks = gridGraph.pageRank(tol, resetProb).vertices.cache()
val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb)))
assert(compareRanks(staticRanks, referenceRanks) < errorTol)
assert(compareRanks(dynamicRanks, referenceRanks) < errorTol)
- assert(compareRanks(standaloneRanks, referenceRanks) < errorTol)
}
} // end of Grid PageRank
@@ -115,12 +110,10 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
val numIter = 10
val errorTol = 1.0e-5
- val staticRanks = PageRank.run(chain, numIter, resetProb).vertices.cache()
- val dynamicRanks = PageRank.runUntillConvergence(chain, tol, resetProb).vertices.cache()
- val standaloneRanks = PageRank.runStandalone(chain, tol, resetProb).cache()
+ val staticRanks = chain.staticPageRank(numIter, resetProb).vertices.cache()
+ val dynamicRanks = chain.pageRank(tol, resetProb).vertices.cache()
assert(compareRanks(staticRanks, dynamicRanks) < errorTol)
- assert(compareRanks(dynamicRanks, standaloneRanks) < errorTol)
}
}
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponentsSuite.scala
index 696b80944e..fee7d20161 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponentsSuite.scala
@@ -16,7 +16,7 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
val vertices = sc.parallelize((1L to 5L).map(x => (x, -1)))
val edges = sc.parallelize(Seq.empty[Edge[Int]])
val graph = Graph(vertices, edges)
- val sccGraph = StronglyConnectedComponents.run(graph, 5)
+ val sccGraph = graph.stronglyConnectedComponents(5)
for ((id, scc) <- sccGraph.vertices.collect) {
assert(id == scc)
}
@@ -27,7 +27,7 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
withSpark { sc =>
val rawEdges = sc.parallelize((0L to 6L).map(x => (x, (x + 1) % 7)))
val graph = Graph.fromEdgeTuples(rawEdges, -1)
- val sccGraph = StronglyConnectedComponents.run(graph, 20)
+ val sccGraph = graph.stronglyConnectedComponents(20)
for ((id, scc) <- sccGraph.vertices.collect) {
assert(0L == scc)
}
@@ -42,7 +42,7 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
Array(6L -> 0L, 5L -> 7L)
val rawEdges = sc.parallelize(edges)
val graph = Graph.fromEdgeTuples(rawEdges, -1)
- val sccGraph = StronglyConnectedComponents.run(graph, 20)
+ val sccGraph = graph.stronglyConnectedComponents(20)
for ((id, scc) <- sccGraph.vertices.collect) {
if (id < 3)
assert(0L == scc)
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/TriangleCountSuite.scala
index 0e59912754..b85b289da6 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/TriangleCountSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/TriangleCountSuite.scala
@@ -15,7 +15,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext {
withSpark { sc =>
val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2)
val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
- val triangleCount = TriangleCount.run(graph)
+ val triangleCount = graph.triangleCount()
val verts = triangleCount.vertices
verts.collect.foreach { case (vid, count) => assert(count === 1) }
}
@@ -27,7 +27,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext {
Array(0L -> -1L, -1L -> -2L, -2L -> 0L)
val rawEdges = sc.parallelize(triangles, 2)
val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
- val triangleCount = TriangleCount.run(graph)
+ val triangleCount = graph.triangleCount()
val verts = triangleCount.vertices
verts.collect().foreach { case (vid, count) =>
if (vid == 0) {
@@ -47,7 +47,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext {
val revTriangles = triangles.map { case (a,b) => (b,a) }
val rawEdges = sc.parallelize(triangles ++ revTriangles, 2)
val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
- val triangleCount = TriangleCount.run(graph)
+ val triangleCount = graph.triangleCount()
val verts = triangleCount.vertices
verts.collect().foreach { case (vid, count) =>
if (vid == 0) {
@@ -64,7 +64,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext {
val rawEdges = sc.parallelize(Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
Array(0L -> 1L, 1L -> 2L, 2L -> 0L), 2)
val graph = Graph.fromEdgeTuples(rawEdges, true, uniqueEdges = Some(RandomVertexCut)).cache()
- val triangleCount = TriangleCount.run(graph)
+ val triangleCount = graph.triangleCount()
val verts = triangleCount.vertices
verts.collect.foreach { case (vid, count) => assert(count === 1) }
}