From 1f4c7f7ecc9d2393663fc4d059e71fe4c70bad84 Mon Sep 17 00:00:00 2001 From: Semih Salihoglu Date: Mon, 24 Feb 2014 22:42:30 -0800 Subject: Graph primitives2 Hi guys, I'm following Joey and Ankur's suggestions to add collectEdges and pickRandomVertex. I'm also adding the tests for collectEdges and refactoring one method getCycleGraph in GraphOpsSuite.scala. Thank you, semih Author: Semih Salihoglu Closes #580 from semihsalihoglu/GraphPrimitives2 and squashes the following commits: 937d3ec [Semih Salihoglu] - Fixed the scalastyle errors. a69a152 [Semih Salihoglu] - Adding collectEdges and pickRandomVertices. - Adding tests for collectEdges. - Refactoring a getCycle utility function for GraphOpsSuite.scala. 41265a6 [Semih Salihoglu] - Adding collectEdges and pickRandomVertex. - Adding tests for collectEdges. - Recycling a getCycle utility test file. --- .../org/apache/spark/graphx/GraphOpsSuite.scala | 134 +++++++++++++++++++-- 1 file changed, 125 insertions(+), 9 deletions(-) (limited to 'graphx/src/test/scala') diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index bc2ad5677f..6386306c04 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -42,21 +42,20 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { test("collectNeighborIds") { withSpark { sc => - val chain = (0 until 100).map(x => (x, (x+1)%100) ) - val rawEdges = sc.parallelize(chain, 3).map { case (s,d) => (s.toLong, d.toLong) } - val graph = Graph.fromEdgeTuples(rawEdges, 1.0).cache() + val graph = getCycleGraph(sc, 100) val nbrs = graph.collectNeighborIds(EdgeDirection.Either).cache() - assert(nbrs.count === chain.size) + assert(nbrs.count === 100) assert(graph.numVertices === nbrs.count) nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2) } - nbrs.collect.foreach { case (vid, nbrs) => - val s = nbrs.toSet - assert(s.contains((vid + 1) % 100)) - assert(s.contains(if (vid > 0) vid - 1 else 99 )) + nbrs.collect.foreach { + case (vid, nbrs) => + val s = nbrs.toSet + assert(s.contains((vid + 1) % 100)) + assert(s.contains(if (vid > 0) vid - 1 else 99)) } } } - + test ("filter") { withSpark { sc => val n = 5 @@ -80,4 +79,121 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { } } + test("collectEdgesCycleDirectionOut") { + withSpark { sc => + val graph = getCycleGraph(sc, 100) + val edges = graph.collectEdges(EdgeDirection.Out).cache() + assert(edges.count == 100) + edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) } + edges.collect.foreach { + case (vid, edges) => + val s = edges.toSet + val edgeDstIds = s.map(e => e.dstId) + assert(edgeDstIds.contains((vid + 1) % 100)) + } + } + } + + test("collectEdgesCycleDirectionIn") { + withSpark { sc => + val graph = getCycleGraph(sc, 100) + val edges = graph.collectEdges(EdgeDirection.In).cache() + assert(edges.count == 100) + edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) } + edges.collect.foreach { + case (vid, edges) => + val s = edges.toSet + val edgeSrcIds = s.map(e => e.srcId) + assert(edgeSrcIds.contains(if (vid > 0) vid - 1 else 99)) + } + } + } + + test("collectEdgesCycleDirectionEither") { + withSpark { sc => + val graph = getCycleGraph(sc, 100) + val edges = graph.collectEdges(EdgeDirection.Either).cache() + assert(edges.count == 100) + edges.collect.foreach { case (vid, edges) => assert(edges.size == 2) } + edges.collect.foreach { + case (vid, edges) => + val s = edges.toSet + val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId) + assert(edgeIds.contains((vid + 1) % 100)) + assert(edgeIds.contains(if (vid > 0) vid - 1 else 99)) + } + } + } + + test("collectEdgesChainDirectionOut") { + withSpark { sc => + val graph = getChainGraph(sc, 50) + val edges = graph.collectEdges(EdgeDirection.Out).cache() + assert(edges.count == 49) + edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) } + edges.collect.foreach { + case (vid, edges) => + val s = edges.toSet + val edgeDstIds = s.map(e => e.dstId) + assert(edgeDstIds.contains(vid + 1)) + } + } + } + + test("collectEdgesChainDirectionIn") { + withSpark { sc => + val graph = getChainGraph(sc, 50) + val edges = graph.collectEdges(EdgeDirection.In).cache() + // We expect only 49 because collectEdges does not return vertices that do + // not have any edges in the specified direction. + assert(edges.count == 49) + edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) } + edges.collect.foreach { + case (vid, edges) => + val s = edges.toSet + val edgeDstIds = s.map(e => e.srcId) + assert(edgeDstIds.contains((vid - 1) % 100)) + } + } + } + + test("collectEdgesChainDirectionEither") { + withSpark { sc => + val graph = getChainGraph(sc, 50) + val edges = graph.collectEdges(EdgeDirection.Either).cache() + // We expect only 49 because collectEdges does not return vertices that do + // not have any edges in the specified direction. + assert(edges.count === 50) + edges.collect.foreach { + case (vid, edges) => if (vid > 0 && vid < 49) assert(edges.size == 2) + else assert(edges.size == 1) + } + edges.collect.foreach { + case (vid, edges) => + val s = edges.toSet + val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId) + if (vid == 0) { assert(edgeIds.contains(1)) } + else if (vid == 49) { assert(edgeIds.contains(48)) } + else { + assert(edgeIds.contains(vid + 1)) + assert(edgeIds.contains(vid - 1)) + } + } + } + } + + private def getCycleGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = { + val cycle = (0 until numVertices).map(x => (x, (x + 1) % numVertices)) + getGraphFromSeq(sc, cycle) + } + + private def getChainGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = { + val chain = (0 until numVertices - 1).map(x => (x, (x + 1))) + getGraphFromSeq(sc, chain) + } + + private def getGraphFromSeq(sc: SparkContext, seq: IndexedSeq[(Int, Int)]): Graph[Double, Int] = { + val rawEdges = sc.parallelize(seq, 3).map { case (s, d) => (s.toLong, d.toLong) } + Graph.fromEdgeTuples(rawEdges, 1.0).cache() + } } -- cgit v1.2.3