aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala2
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala4
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala16
3 files changed, 20 insertions, 2 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index b908860310..796082721d 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -151,7 +151,7 @@ object Pregel extends Logging {
// count the iteration
i += 1
}
-
+ messages.unpersist(blocking = false)
g
} // end of apply
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
index 859f896039..f72cbb1524 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
@@ -47,9 +47,11 @@ object ConnectedComponents {
}
}
val initialMessage = Long.MaxValue
- Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)(
+ val pregelGraph = Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)(
vprog = (id, attr, msg) => math.min(attr, msg),
sendMsg = sendMessage,
mergeMsg = (a, b) => math.min(a, b))
+ ccGraph.unpersist()
+ pregelGraph
} // end of connectedComponents
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 1f5e27d550..2fbc6f069d 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -428,4 +428,20 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext {
}
}
+ test("unpersist graph RDD") {
+ withSpark { sc =>
+ val vert = sc.parallelize(List((1L, "a"), (2L, "b"), (3L, "c")), 1)
+ val edges = sc.parallelize(List(Edge[Long](1L, 2L), Edge[Long](1L, 3L)), 1)
+ val g0 = Graph(vert, edges)
+ val g = g0.partitionBy(PartitionStrategy.EdgePartition2D, 2)
+ val cc = g.connectedComponents()
+ assert(sc.getPersistentRDDs.nonEmpty)
+ cc.unpersist()
+ g.unpersist()
+ g0.unpersist()
+ vert.unpersist()
+ edges.unpersist()
+ assert(sc.getPersistentRDDs.isEmpty)
+ }
+ }
}