author    Ankur Dave <ankurdave@gmail.com>    2014-01-09 23:25:35 -0800
committer Ankur Dave <ankurdave@gmail.com>    2014-01-09 23:25:35 -0800
commit    8ae108f6c48528f3bb7498d586eb51a70c043764
tree      1d6cb878a9f56859c77b2715b3bbe8ea583c8c72
parent    210f2dd84fb2de623745a162377b989712f7ef0f
Unpersist previous iterations in Pregel
 graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala                 |  2 +-
 graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala                   |  5 +++++
 graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala                    | 19 +++++++++++++++----
 graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala                 |  5 +++++
 graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala       |  1 +
 graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala | 16 ++++++++++++++--
 6 files changed, 41 insertions(+), 7 deletions(-)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala b/graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala
index def6d69190..2c4c885a04 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Analytics.scala
@@ -83,7 +83,7 @@ object Analytics extends Logging {
println("GRAPHX: Number of edges " + graph.edges.count)
//val pr = Analytics.pagerank(graph, numIter)
- val pr = graph.pageRank(tol).vertices
+ val pr = graph.pageRank(tol).vertices.cache()
println("GRAPHX: Total rank: " + pr.map(_._2).reduce(_+_))
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
index e4ef460e6f..7fd6580626 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
@@ -44,6 +44,11 @@ class EdgeRDD[@specialized ED: ClassTag](
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
override def cache(): EdgeRDD[ED] = persist()
+ override def unpersist(blocking: Boolean = true): EdgeRDD[ED] = {
+ partitionsRDD.unpersist(blocking)
+ this
+ }
+
def mapEdgePartitions[ED2: ClassTag](f: (PartitionID, EdgePartition[ED]) => EdgePartition[ED2])
: EdgeRDD[ED2] = {
// iter => iter.map { case (pid, ep) => (pid, f(ep)) }
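
The new `unpersist` override mirrors the existing `persist`/`cache` delegation: `EdgeRDD` is a thin wrapper whose data lives in `partitionsRDD`, so persistence calls are forwarded there and `this` is returned for chaining (`VertexRDD` below receives the identical override). A simplified, hypothetical wrapper showing the pattern in isolation, written against the pre-1.0 Spark API used here (where `persist` returns the RDD subtype rather than `this.type`); it is not the actual `EdgeRDD` source:

    import scala.reflect.ClassTag
    import org.apache.spark.{OneToOneDependency, Partition, TaskContext}
    import org.apache.spark.rdd.RDD
    import org.apache.spark.storage.StorageLevel

    // Hypothetical wrapper RDD: every persistence call forwards to the
    // inner partitionsRDD, so the wrapper itself never holds cached blocks.
    class WrapperRDD[T: ClassTag](val partitionsRDD: RDD[T])
      extends RDD[T](partitionsRDD.context,
                     List(new OneToOneDependency(partitionsRDD))) {

      override protected def getPartitions: Array[Partition] =
        partitionsRDD.partitions

      override def compute(split: Partition, context: TaskContext): Iterator[T] =
        partitionsRDD.iterator(split, context)

      override def persist(newLevel: StorageLevel): WrapperRDD[T] = {
        partitionsRDD.persist(newLevel)
        this
      }

      override def cache(): WrapperRDD[T] = persist(StorageLevel.MEMORY_ONLY)

      override def unpersist(blocking: Boolean = true): WrapperRDD[T] = {
        partitionsRDD.unpersist(blocking)
        this
      }
    }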
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 8ddb788135..ed8733a806 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -93,25 +93,36 @@ object Pregel {
mergeMsg: (A, A) => A)
: Graph[VD, ED] = {
- var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) )
+ var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ).cache()
// compute the messages
- var messages = g.mapReduceTriplets(sendMsg, mergeMsg).cache()
+ var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
var activeMessages = messages.count()
// Loop
+ var prevG: Graph[VD, ED] = null
var i = 0
while (activeMessages > 0 && i < maxIterations) {
// Receive the messages. Vertices that didn't get any messages do not appear in newVerts.
val newVerts = g.vertices.innerJoin(messages)(vprog).cache()
// Update the graph with the new vertices.
+ prevG = g
g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }
+ g.vertices.cache()
val oldMessages = messages
// Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't
- // get to send messages.
+ // get to send messages. We must cache messages so it can be materialized on the next line,
+ // allowing us to uncache the previous iteration.
messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, EdgeDirection.Out))).cache()
+ // Materializes messages, newVerts, and g.rvv (which materializes g.vertices). Hides
+ // oldMessages (depended on by newVerts), newVerts (depended on by messages), prevG.vertices
+ // (depended on by newVerts and g.vertices), and prevG.rvv (depended on by oldMessages and
+ // g.rvv).
activeMessages = messages.count()
- // after counting we can unpersist the old messages
+ // Unpersist hidden RDDs
oldMessages.unpersist(blocking=false)
+ newVerts.unpersist(blocking=false)
+ prevG.vertices.unpersist(blocking=false)
+ prevG.asInstanceOf[org.apache.spark.graphx.impl.GraphImpl[VD, ED]].replicatedVertexView.unpersist(blocking=false)
// count the iteration
i += 1
}
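
This loop is the heart of the commit. Each iteration caches its new graph and message RDDs, forces them to materialize via `messages.count()`, and only then unpersists the previous iteration's RDDs, which are no longer needed for recomputation. Unpersisting before the count would evict the very blocks the new RDDs still read from. The same discipline distilled to plain RDDs (a hypothetical sketch, not Pregel itself):

    import org.apache.spark.rdd.RDD

    // Hypothetical sketch: cache the new state, materialize it with an
    // action, then drop the old state. Reversing the last two steps would
    // force `next` to recompute from an evicted parent.
    def iterate(initial: RDD[Int], steps: Int): RDD[Int] = {
      var state = initial.cache()
      for (_ <- 1 to steps) {
        val next = state.map(_ + 1).cache()
        next.count()                      // materialize next while state is still cached
        state.unpersist(blocking = false) // now safe to evict the old iteration
        state = next
      }
      state
    }

Pregel passes `blocking = false` so eviction happens asynchronously and the next iteration is not held up waiting for block removal.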
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
index cfee9b089f..971e2615d4 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
@@ -98,6 +98,11 @@ class VertexRDD[@specialized VD: ClassTag](
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
override def cache(): VertexRDD[VD] = persist()
+ override def unpersist(blocking: Boolean = true): VertexRDD[VD] = {
+ partitionsRDD.unpersist(blocking)
+ this
+ }
+
/** Return the number of vertices in this set. */
override def count(): Long = {
partitionsRDD.map(_.size).reduce(_ + _)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala
index b423104eda..179d310554 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala
@@ -125,6 +125,7 @@ object PageRank extends Logging {
.mapTriplets( e => 1.0 / e.srcAttr )
// Set the vertex attributes to (initalPR, delta = 0)
.mapVertices( (id, attr) => (0.0, 0.0) )
+ .cache()
// Display statistics about pagerank
logInfo(pagerankGraph.statistics.toString)
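
The `.cache()` added here matters because `pagerankGraph` has two consumers, the `statistics` call just above and the Pregel run that follows, and in this era `Graph.cache()` persists the graph's underlying vertex and edge data. A hypothetical sketch of the idiom, assuming a previously constructed `graph`:

    // Hypothetical sketch: cache a derived graph once, then let multiple
    // consumers reuse the materialized copy instead of re-deriving it.
    val derived = graph.mapVertices((id, attr) => (0.0, 0.0)).cache()
    println(derived.vertices.count()) // first consumer materializes the cache
    println(derived.edges.count())    // second consumer reuses it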
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
index 63180bc3af..0e2f5a9dd9 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
@@ -14,9 +14,11 @@ import org.apache.spark.graphx._
* specified, `updatedVerts` are treated as incremental updates to the previous view. Otherwise, a
* fresh view is created.
*
- * The view is always cached (i.e., once it is created, it remains materialized). This avoids
+ * The view is always cached (i.e., once it is evaluated, it remains materialized). This avoids
* constructing it twice if the user calls graph.triplets followed by graph.mapReduceTriplets, for
- * example.
+ * example. However, it means iterative algorithms must manually call `Graph.unpersist` on previous
+ * iterations' graphs for best GC performance. See the implementation of
+ * [[org.apache.spark.graphx.Pregel]] for an example.
*/
private[impl]
class ReplicatedVertexView[VD: ClassTag](
@@ -51,6 +53,16 @@ class ReplicatedVertexView[VD: ClassTag](
private lazy val dstAttrOnly: RDD[(PartitionID, VertexPartition[VD])] = create(false, true)
private lazy val noAttrs: RDD[(PartitionID, VertexPartition[VD])] = create(false, false)
+ def unpersist(blocking: Boolean = true): ReplicatedVertexView[VD] = {
+ bothAttrs.unpersist(blocking)
+ srcAttrOnly.unpersist(blocking)
+ dstAttrOnly.unpersist(blocking)
+ noAttrs.unpersist(blocking)
+ // Don't unpersist localVertexIDMap because a future ReplicatedVertexView may be using it
+ // without modification
+ this
+ }
+
def get(includeSrc: Boolean, includeDst: Boolean): RDD[(PartitionID, VertexPartition[VD])] = {
(includeSrc, includeDst) match {
case (true, true) => bothAttrs
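
One subtlety in the `ReplicatedVertexView.unpersist` addition: the four views are `lazy val`s, so calling `unpersist` on a view that was never used forces its initializer to run first. That only constructs the RDD object and unmarks it for caching; no Spark job runs, because `unpersist` never materializes anything. A small hypothetical illustration of the lazy-val behavior:

    import org.apache.spark.rdd.RDD

    // Hypothetical illustration: unpersisting a lazy-val RDD forces the
    // val to initialize (building the RDD lineage, which is cheap) but
    // triggers no job, since nothing has been computed or cached yet.
    class Views(base: RDD[Int]) {
      private lazy val doubled: RDD[Int] = base.map(_ * 2).cache()

      def unpersistAll(blocking: Boolean = true): this.type = {
        doubled.unpersist(blocking) // initializes `doubled` on first touch
        this
      }
    }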