diff options
3 files changed, 20 insertions, 2 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index b908860310..796082721d 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -151,7 +151,7 @@ object Pregel extends Logging { // count the iteration i += 1 } - + messages.unpersist(blocking = false) g } // end of apply diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala index 859f896039..f72cbb1524 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala @@ -47,9 +47,11 @@ object ConnectedComponents { } } val initialMessage = Long.MaxValue - Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)( + val pregelGraph = Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)( vprog = (id, attr, msg) => math.min(attr, msg), sendMsg = sendMessage, mergeMsg = (a, b) => math.min(a, b)) + ccGraph.unpersist() + pregelGraph } // end of connectedComponents } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 1f5e27d550..2fbc6f069d 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -428,4 +428,20 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext { } } + test("unpersist graph RDD") { + withSpark { sc => + val vert = sc.parallelize(List((1L, "a"), (2L, "b"), (3L, "c")), 1) + val edges = sc.parallelize(List(Edge[Long](1L, 2L), Edge[Long](1L, 3L)), 1) + val g0 = Graph(vert, edges) + val g = g0.partitionBy(PartitionStrategy.EdgePartition2D, 2) + val cc = g.connectedComponents() + assert(sc.getPersistentRDDs.nonEmpty) + cc.unpersist() + g.unpersist() + g0.unpersist() + vert.unpersist() + edges.unpersist() + assert(sc.getPersistentRDDs.isEmpty) + } + } } |