diff options
author | Ankur Dave <ankurdave@gmail.com> | 2014-01-10 00:35:02 -0800 |
---|---|---|
committer | Ankur Dave <ankurdave@gmail.com> | 2014-01-10 00:35:02 -0800 |
commit | ba511f890ee0d7f85746126c4be734538ede21ea (patch) | |
tree | afebff06ddf4bc65f57df0e20844262230c5849d /graphx | |
parent | 8b6b8ac87f6ffb92b3395344bf2696d5c7fb3798 (diff) | |
download | spark-ba511f890ee0d7f85746126c4be734538ede21ea.tar.gz spark-ba511f890ee0d7f85746126c4be734538ede21ea.tar.bz2 spark-ba511f890ee0d7f85746126c4be734538ede21ea.zip |
Avoid recomputation by caching all multiply-used RDDs
Diffstat (limited to 'graphx')
11 files changed, 67 insertions, 53 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala index 179d310554..ab447d5422 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/PageRank.scala @@ -61,6 +61,7 @@ object PageRank extends Logging { .mapTriplets( e => 1.0 / e.srcAttr ) // Set the vertex attributes to the initial pagerank values .mapVertices( (id, attr) => 1.0 ) + .cache() // Display statistics about pagerank logInfo(pagerankGraph.statistics.toString) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/SVDPlusPlus.scala index 8fdfa3d907..2a13553d79 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/SVDPlusPlus.scala @@ -42,6 +42,7 @@ object SVDPlusPlus { } // calculate global rating mean + edges.cache() val (rs, rc) = edges.map(e => (e.attr, 1L)).reduce((a, b) => (a._1 + b._1, a._2 + b._2)) val u = rs / rc @@ -72,11 +73,13 @@ object SVDPlusPlus { for (i <- 0 until conf.maxIters) { // phase 1, calculate pu + |N(u)|^(-0.5)*sum(y) for user nodes + g.cache() var t1 = g.mapReduceTriplets(et => Iterator((et.srcId, et.dstAttr._2)), (g1: RealVector, g2: RealVector) => g1.add(g2)) g = g.outerJoinVertices(t1) { (vid: VertexID, vd: (RealVector, RealVector, Double, Double), msg: Option[RealVector]) => if (msg.isDefined) (vd._1, vd._1.add(msg.get.mapMultiply(vd._4)), vd._3, vd._4) else vd } // phase 2, update p for user nodes and q, y for item nodes + g.cache() val t2 = g.mapReduceTriplets(mapTrainF(conf, u), (g1: (RealVector, RealVector, Double), g2: (RealVector, RealVector, Double)) => (g1._1.add(g2._1), g1._2.add(g2._2), g1._3 + g2._3)) g = g.outerJoinVertices(t2) { (vid: VertexID, vd: (RealVector, RealVector, Double, Double), msg: Option[(RealVector, RealVector, Double)]) => @@ -94,6 +97,7 @@ object SVDPlusPlus { val err = (et.attr - pred) * (et.attr - pred) Iterator((et.dstId, err)) } + g.cache() val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2) g = g.outerJoinVertices(t3) { (vid: VertexID, vd: (RealVector, RealVector, Double, Double), msg: Option[Double]) => if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd diff --git a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponents.scala index 49ec91aedd..864f0ec57c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/algorithms/StronglyConnectedComponents.scala @@ -22,7 +22,7 @@ object StronglyConnectedComponents { // the graph we update with final SCC ids, and the graph we return at the end var sccGraph = graph.mapVertices { case (vid, _) => vid } // graph we are going to work with in our iterations - var sccWorkGraph = graph.mapVertices { case (vid, _) => (vid, false) } + var sccWorkGraph = graph.mapVertices { case (vid, _) => (vid, false) }.cache() var numVertices = sccWorkGraph.numVertices var iter = 0 @@ -32,10 +32,9 @@ object StronglyConnectedComponents { numVertices = sccWorkGraph.numVertices sccWorkGraph = sccWorkGraph.outerJoinVertices(sccWorkGraph.outDegrees) { (vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true) - } - sccWorkGraph = sccWorkGraph.outerJoinVertices(sccWorkGraph.inDegrees) { + }.outerJoinVertices(sccWorkGraph.inDegrees) { (vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true) - } + }.cache() // get all vertices to be removed val finalVertices = sccWorkGraph.vertices @@ -47,7 +46,7 @@ object StronglyConnectedComponents { (vid, scc, opt) => opt.getOrElse(scc) } // only keep vertices that are not final - sccWorkGraph = sccWorkGraph.subgraph(vpred = (vid, data) => !data._2) + sccWorkGraph = sccWorkGraph.subgraph(vpred = (vid, data) => !data._2).cache() } while (sccWorkGraph.numVertices < numVertices) sccWorkGraph = sccWorkGraph.mapVertices{ case (vid, (color, isFinal)) => (vid, isFinal) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 2dd1324d4f..987a646c0c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -32,19 +32,6 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( @transient val replicatedVertexView: ReplicatedVertexView[VD]) extends Graph[VD, ED] with Serializable { - def this( - vertices: VertexRDD[VD], - edges: EdgeRDD[ED], - routingTable: RoutingTable) = { - this(vertices, edges, routingTable, new ReplicatedVertexView(vertices, edges, routingTable)) - } - - def this( - vertices: VertexRDD[VD], - edges: EdgeRDD[ED]) = { - this(vertices, edges, new RoutingTable(edges, vertices)) - } - /** Return a RDD that brings edges together with their source and destination vertices. */ @transient override val triplets: RDD[EdgeTriplet[VD, ED]] = { val vdTag = classTag[VD] @@ -90,7 +77,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( val edgePartition = builder.toEdgePartition Iterator((pid, edgePartition)) }, preservesPartitioning = true).cache()) - new GraphImpl(vertices, newEdges) + GraphImpl(vertices, newEdges) } override def statistics: Map[String, Any] = { @@ -166,7 +153,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( override def mapVertices[VD2: ClassTag](f: (VertexID, VD) => VD2): Graph[VD2, ED] = { if (classTag[VD] equals classTag[VD2]) { // The map preserves type, so we can use incremental replication - val newVerts = vertices.mapVertexPartitions(_.map(f)) + val newVerts = vertices.mapVertexPartitions(_.map(f)).cache() val changedVerts = vertices.asInstanceOf[VertexRDD[VD2]].diff(newVerts) val newReplicatedVertexView = new ReplicatedVertexView[VD2]( changedVerts, edges, routingTable, @@ -174,7 +161,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( new GraphImpl(newVerts, edges, routingTable, newReplicatedVertexView) } else { // The map does not preserve type, so we must re-replicate all vertices - new GraphImpl(vertices.mapVertexPartitions(_.map(f)), edges, routingTable) + GraphImpl(vertices.mapVertexPartitions(_.map(f)), edges, routingTable) } } @@ -336,7 +323,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( } else { // updateF does not preserve type, so we must re-replicate all vertices val newVerts = vertices.leftJoin(updates)(updateF) - new GraphImpl(newVerts, edges, routingTable) + GraphImpl(newVerts, edges, routingTable) } } @@ -382,7 +369,29 @@ object GraphImpl { val vertexRDD = VertexRDD(vids, vPartitioned, defaultVertexAttr) - new GraphImpl(vertexRDD, edgeRDD) + GraphImpl(vertexRDD, edgeRDD) + } + + def apply[VD: ClassTag, ED: ClassTag]( + vertices: VertexRDD[VD], + edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { + // Cache RDDs that are referenced multiple times + edges.cache() + + GraphImpl(vertices, edges, new RoutingTable(edges, vertices)) + } + + def apply[VD: ClassTag, ED: ClassTag]( + vertices: VertexRDD[VD], + edges: EdgeRDD[ED], + routingTable: RoutingTable): GraphImpl[VD, ED] = { + // Cache RDDs that are referenced multiple times. `routingTable` is cached by default, so we + // don't cache it explicitly. + vertices.cache() + edges.cache() + + new GraphImpl( + vertices, edges, routingTable, new ReplicatedVertexView(vertices, edges, routingTable)) } /** @@ -413,7 +422,7 @@ object GraphImpl { val vids = collectVertexIDsFromEdges(edges, new HashPartitioner(edges.partitions.size)) // Create the VertexRDD. val vertices = VertexRDD(vids.mapValues(x => defaultVertexAttr)) - new GraphImpl(vertices, edges) + GraphImpl(vertices, edges) } /** Collects all vids mentioned in edges and partitions them by partitioner. */ diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index cc281fce99..cd3c0bbd30 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -53,8 +53,8 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val chain = (0 until 100).map(x => (x, (x+1)%100) ) val rawEdges = sc.parallelize(chain, 3).map { case (s,d) => (s.toLong, d.toLong) } - val graph = Graph.fromEdgeTuples(rawEdges, 1.0) - val nbrs = graph.collectNeighborIds(EdgeDirection.Both) + val graph = Graph.fromEdgeTuples(rawEdges, 1.0).cache() + val nbrs = graph.collectNeighborIds(EdgeDirection.Both).cache() assert(nbrs.count === chain.size) assert(graph.numVertices === nbrs.count) nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2) } @@ -71,14 +71,14 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { val n = 5 val vertices = sc.parallelize((0 to n).map(x => (x:VertexID, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) - val graph: Graph[Int, Int] = Graph(vertices, edges) + val graph: Graph[Int, Int] = Graph(vertices, edges).cache() val filteredGraph = graph.filter( graph => { val degrees: VertexRDD[Int] = graph.outDegrees graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)} }, vpred = (vid: VertexID, deg:Int) => deg > 0 - ) + ).cache() val v = filteredGraph.vertices.collect().toSet assert(v === Set((0,0))) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 094fa722a0..c32a6cbb81 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -175,7 +175,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val n = 5 val vertices = sc.parallelize((0 to n).map(x => (x:VertexID, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) - val graph: Graph[Int, Int] = Graph(vertices, edges) + val graph: Graph[Int, Int] = Graph(vertices, edges).cache() val subgraph = graph.subgraph( e => e.dstId != 4L, @@ -211,7 +211,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { test("mapReduceTriplets") { withSpark { sc => val n = 5 - val star = starGraph(sc, n).mapVertices { (_, _) => 0 } + val star = starGraph(sc, n).mapVertices { (_, _) => 0 }.cache() val starDeg = star.joinVertices(star.degrees){ (vid, oldV, deg) => deg } val neighborDegreeSums = starDeg.mapReduceTriplets( edge => Iterator((edge.srcId, edge.dstAttr), (edge.dstId, edge.srcAttr)), @@ -235,7 +235,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { // outerJoinVertices followed by mapReduceTriplets(activeSetOpt) val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexID, (x+1) % n: VertexID)), 3) val ring = Graph.fromEdgeTuples(ringEdges, 0) .mapVertices((vid, attr) => vid).cache() - val changed = ring.vertices.filter { case (vid, attr) => attr % 2 == 1 }.mapValues(-_) + val changed = ring.vertices.filter { case (vid, attr) => attr % 2 == 1 }.mapValues(-_).cache() val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => newOpt.getOrElse(old) } val numOddNeighbors = changedGraph.mapReduceTriplets(et => { // Map function should only run on edges with source in the active set @@ -252,7 +252,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { test("outerJoinVertices") { withSpark { sc => val n = 5 - val reverseStar = starGraph(sc, n).reverse + val reverseStar = starGraph(sc, n).reverse.cache() // outerJoinVertices changing type val reverseStarDegrees = reverseStar.outerJoinVertices(reverseStar.outDegrees) { (vid, a, bOpt) => bOpt.getOrElse(0) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala index 429622357f..1ff3d75633 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala @@ -10,8 +10,8 @@ class PregelSuite extends FunSuite with LocalSparkContext { test("1 iteration") { withSpark { sc => val n = 5 - val star = - Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: VertexID, x: VertexID)), 3), "v") + val starEdges = (1 to n).map(x => (0: VertexID, x: VertexID)) + val star = Graph.fromEdgeTuples(sc.parallelize(starEdges, 3), "v").cache() val result = Pregel(star, 0)( (vid, attr, msg) => attr, et => Iterator.empty, @@ -27,7 +27,7 @@ class PregelSuite extends FunSuite with LocalSparkContext { sc.parallelize((1 until n).map(x => (x: VertexID, x + 1: VertexID)), 3), 0).cache() assert(chain.vertices.collect.toSet === (1 to n).map(x => (x: VertexID, 0)).toSet) - val chainWithSeed = chain.mapVertices { (vid, attr) => if (vid == 1) 1 else 0 } + val chainWithSeed = chain.mapVertices { (vid, attr) => if (vid == 1) 1 else 0 }.cache() assert(chainWithSeed.vertices.collect.toSet === Set((1: VertexID, 1)) ++ (2 to n).map(x => (x: VertexID, 0)).toSet) val result = Pregel(chainWithSeed, 0)( diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index 573b708e89..d94a3aa67c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -33,8 +33,8 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { test("diff") { withSpark { sc => val n = 100 - val verts = vertices(sc, n) - val flipEvens = verts.mapValues(x => if (x % 2 == 0) -x else x) + val verts = vertices(sc, n).cache() + val flipEvens = verts.mapValues(x => if (x % 2 == 0) -x else x).cache() // diff should keep only the changed vertices assert(verts.diff(flipEvens).map(_._2).collect().toSet === (2 to n by 2).map(-_).toSet) // diff should keep the vertex values from `other` @@ -45,8 +45,8 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { test("leftJoin") { withSpark { sc => val n = 100 - val verts = vertices(sc, n) - val evens = verts.filter(q => ((q._2 % 2) == 0)) + val verts = vertices(sc, n).cache() + val evens = verts.filter(q => ((q._2 % 2) == 0)).cache() // leftJoin with another VertexRDD assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet) @@ -60,8 +60,8 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { test("innerJoin") { withSpark { sc => val n = 100 - val verts = vertices(sc, n) - val evens = verts.filter(q => ((q._2 % 2) == 0)) + val verts = vertices(sc, n).cache() + val evens = verts.filter(q => ((q._2 % 2) == 0)).cache() // innerJoin with another VertexRDD assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect.toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala index 209191ef07..16fc3fe5a2 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/ConnectedComponentsSuite.scala @@ -13,8 +13,8 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { test("Grid Connected Components") { withSpark { sc => - val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).cache() - val ccGraph = gridGraph.connectedComponents().cache() + val gridGraph = GraphGenerators.gridGraph(sc, 10, 10) + val ccGraph = gridGraph.connectedComponents() val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum assert(maxCCid === 0) } @@ -23,8 +23,8 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { test("Reverse Grid Connected Components") { withSpark { sc => - val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse.cache() - val ccGraph = gridGraph.connectedComponents().cache() + val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse + val ccGraph = gridGraph.connectedComponents() val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum assert(maxCCid === 0) } @@ -36,8 +36,8 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { val chain1 = (0 until 9).map(x => (x, x+1) ) val chain2 = (10 until 20).map(x => (x, x+1) ) val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } - val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0).cache() - val ccGraph = twoChains.connectedComponents().cache() + val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0) + val ccGraph = twoChains.connectedComponents() val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { if(id < 10) { assert(cc === 0) } @@ -59,8 +59,8 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { val chain1 = (0 until 9).map(x => (x, x+1) ) val chain2 = (10 until 20).map(x => (x, x+1) ) val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } - val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse.cache() - val ccGraph = twoChains.connectedComponents().cache() + val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse + val ccGraph = twoChains.connectedComponents() val vertices = ccGraph.vertices.collect for ( (id, cc) <- vertices ) { if (id < 10) { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala index cd857bd3a1..de2c2d1107 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/PageRankSuite.scala @@ -57,7 +57,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { val resetProb = 0.15 val errorTol = 1.0e-5 - val staticRanks1 = starGraph.staticPageRank(numIter = 1, resetProb).vertices.cache() + val staticRanks1 = starGraph.staticPageRank(numIter = 1, resetProb).vertices val staticRanks2 = starGraph.staticPageRank(numIter = 2, resetProb).vertices.cache() // Static PageRank should only take 2 iterations to converge @@ -92,7 +92,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { val staticRanks = gridGraph.staticPageRank(numIter, resetProb).vertices.cache() val dynamicRanks = gridGraph.pageRank(tol, resetProb).vertices.cache() - val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))) + val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache() assert(compareRanks(staticRanks, referenceRanks) < errorTol) assert(compareRanks(dynamicRanks, referenceRanks) < errorTol) @@ -110,8 +110,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext { val numIter = 10 val errorTol = 1.0e-5 - val staticRanks = chain.staticPageRank(numIter, resetProb).vertices.cache() - val dynamicRanks = chain.pageRank(tol, resetProb).vertices.cache() + val staticRanks = chain.staticPageRank(numIter, resetProb).vertices + val dynamicRanks = chain.pageRank(tol, resetProb).vertices assert(compareRanks(staticRanks, dynamicRanks) < errorTol) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/SVDPlusPlusSuite.scala index 06604198d7..7bd93e0e6c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/algorithms/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/algorithms/SVDPlusPlusSuite.scala @@ -20,6 +20,7 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { } val conf = new SVDPlusPlusConf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations var (graph, u) = SVDPlusPlus.run(edges, conf) + graph.cache() val err = graph.vertices.collect.map{ case (vid, vd) => if (vid % 2 == 1) vd._4 else 0.0 }.reduce(_ + _) / graph.triplets.collect.size |