aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--graphx/pom.xml5
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala12
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala81
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala24
4 files changed, 121 insertions, 1 deletions
diff --git a/graphx/pom.xml b/graphx/pom.xml
index bd4e53371b..10d5ba93eb 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -47,6 +47,11 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib-local_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
<groupId>org.apache.xbean</groupId>
<artifactId>xbean-asm5-shaded</artifactId>
</dependency>
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index 868658dfe5..90907300be 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -20,9 +20,10 @@ package org.apache.spark.graphx
import scala.reflect.ClassTag
import scala.util.Random
-import org.apache.spark.SparkException
import org.apache.spark.graphx.lib._
+import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkException
/**
* Contains additional functionality for [[Graph]]. All operations are expressed in terms of the
@@ -392,6 +393,15 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
}
/**
+ * Run parallel personalized PageRank for a given array of source vertices, such
+ * that all random walks are started relative to the source vertices
+ */
+ def staticParallelPersonalizedPageRank(sources: Array[VertexId], numIter: Int,
+ resetProb: Double = 0.15) : Graph[Vector, Double] = {
+ PageRank.runParallelPersonalizedPageRank(graph, numIter, resetProb, sources)
+ }
+
+ /**
* Run Personalized PageRank for a fixed number of iterations with
* with all iterations originating at the source node
* returning a graph with vertex attributes
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index 2f5bd4ed4f..f4b00757a8 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -19,8 +19,11 @@ package org.apache.spark.graphx.lib
import scala.reflect.ClassTag
+import breeze.linalg.{Vector => BV}
+
import org.apache.spark.graphx._
import org.apache.spark.internal.Logging
+import org.apache.spark.ml.linalg.{Vector, Vectors}
/**
* PageRank algorithm implementation. There are two implementations of PageRank implemented.
@@ -163,6 +166,84 @@ object PageRank extends Logging {
}
/**
+ * Run Personalized PageRank for a fixed number of iterations, for a
+ * set of starting nodes in parallel. Returns a graph with vertex attributes
+ * containing the pagerank relative to all starting nodes (as a sparse vector) and
+ * edge attributes the normalized edge weight
+ *
+ * @tparam VD The original vertex attribute (not used)
+ * @tparam ED The original edge attribute (not used)
+ *
+ * @param graph The graph on which to compute personalized pagerank
+ * @param numIter The number of iterations to run
+ * @param resetProb The random reset probability
+ * @param sources The list of sources to compute personalized pagerank from
+ * @return the graph with vertex attributes
+ * containing the pagerank relative to all starting nodes (as a sparse vector) and
+ * edge attributes the normalized edge weight
+ */
+ def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED],
+ numIter: Int, resetProb: Double = 0.15,
+ sources: Array[VertexId]): Graph[Vector, Double] = {
+ // TODO if one sources vertex id is outside of the int range
+ // we won't be able to store its activations in a sparse vector
+ val zero = Vectors.sparse(sources.size, List()).asBreeze
+ val sourcesInitMap = sources.zipWithIndex.map { case (vid, i) =>
+ val v = Vectors.sparse(sources.size, Array(i), Array(resetProb)).asBreeze
+ (vid, v)
+ }.toMap
+ val sc = graph.vertices.sparkContext
+ val sourcesInitMapBC = sc.broadcast(sourcesInitMap)
+ // Initialize the PageRank graph with each edge attribute having
+ // weight 1/outDegree and each source vertex with attribute 1.0.
+ var rankGraph = graph
+ // Associate the degree with each vertex
+ .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) }
+ // Set the weight on the edges based on the degree
+ .mapTriplets(e => 1.0 / e.srcAttr, TripletFields.Src)
+ .mapVertices { (vid, attr) =>
+ if (sourcesInitMapBC.value contains vid) {
+ sourcesInitMapBC.value(vid)
+ } else {
+ zero
+ }
+ }
+
+ var i = 0
+ while (i < numIter) {
+ val prevRankGraph = rankGraph
+ // Propagates the message along outbound edges
+ // and adding start nodes back in with activation resetProb
+ val rankUpdates = rankGraph.aggregateMessages[BV[Double]](
+ ctx => ctx.sendToDst(ctx.srcAttr :* ctx.attr),
+ (a : BV[Double], b : BV[Double]) => a :+ b, TripletFields.Src)
+
+ rankGraph = rankGraph.joinVertices(rankUpdates) {
+ (vid, oldRank, msgSum) =>
+ val popActivations: BV[Double] = msgSum :* (1.0 - resetProb)
+ val resetActivations = if (sourcesInitMapBC.value contains vid) {
+ sourcesInitMapBC.value(vid)
+ } else {
+ zero
+ }
+ popActivations :+ resetActivations
+ }.cache()
+
+ rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices
+ prevRankGraph.vertices.unpersist(false)
+ prevRankGraph.edges.unpersist(false)
+
+ logInfo(s"Parallel Personalized PageRank finished iteration $i.")
+
+ i += 1
+ }
+
+ rankGraph.mapVertices { (vid, attr) =>
+ Vectors.fromBreeze(attr)
+ }
+ }
+
+ /**
* Run a dynamic version of PageRank returning a graph with vertex attributes containing the
* PageRank and edge attributes containing the normalized edge weight.
*
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
index bdff31446f..b6305c8d00 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
@@ -118,11 +118,29 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext {
val dynamicRanks = starGraph.personalizedPageRank(0, 0, resetProb).vertices.cache()
assert(compareRanks(staticRanks2, dynamicRanks) < errorTol)
+ val parallelStaticRanks1 = starGraph
+ .staticParallelPersonalizedPageRank(Array(0), 1, resetProb).mapVertices {
+ case (vertexId, vector) => vector(0)
+ }.vertices.cache()
+ assert(compareRanks(staticRanks1, parallelStaticRanks1) < errorTol)
+
+ val parallelStaticRanks2 = starGraph
+ .staticParallelPersonalizedPageRank(Array(0, 1), 2, resetProb).mapVertices {
+ case (vertexId, vector) => vector(0)
+ }.vertices.cache()
+ assert(compareRanks(staticRanks2, parallelStaticRanks2) < errorTol)
+
// We have one outbound edge from 1 to 0
val otherStaticRanks2 = starGraph.staticPersonalizedPageRank(1, numIter = 2, resetProb)
.vertices.cache()
val otherDynamicRanks = starGraph.personalizedPageRank(1, 0, resetProb).vertices.cache()
+ val otherParallelStaticRanks2 = starGraph
+ .staticParallelPersonalizedPageRank(Array(0, 1), 2, resetProb).mapVertices {
+ case (vertexId, vector) => vector(1)
+ }.vertices.cache()
assert(compareRanks(otherDynamicRanks, otherStaticRanks2) < errorTol)
+ assert(compareRanks(otherStaticRanks2, otherParallelStaticRanks2) < errorTol)
+ assert(compareRanks(otherDynamicRanks, otherParallelStaticRanks2) < errorTol)
}
} // end of test Star PersonalPageRank
@@ -177,6 +195,12 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext {
val dynamicRanks = chain.personalizedPageRank(4, tol, resetProb).vertices
assert(compareRanks(staticRanks, dynamicRanks) < errorTol)
+
+ val parallelStaticRanks = chain
+ .staticParallelPersonalizedPageRank(Array(4), numIter, resetProb).mapVertices {
+ case (vertexId, vector) => vector(0)
+ }.vertices.cache()
+ assert(compareRanks(staticRanks, parallelStaticRanks) < errorTol)
}
}
}