author    Semih Salihoglu <semihsalihoglu@gmail.com>  2014-02-24 22:42:30 -0800
committer Reynold Xin <rxin@apache.org>  2014-02-24 22:42:30 -0800
commit    1f4c7f7ecc9d2393663fc4d059e71fe4c70bad84 (patch)
tree      adb39bfef3d08da56b8e33c882070d911ebaf1bd /graphx/src/test
parent    a4f4fbc8fa5886a8c6ee58ee614de0cc6e67dcd7 (diff)
Graph primitives2
Hi guys,

I'm following Joey and Ankur's suggestions to add collectEdges and pickRandomVertex. I'm also adding the tests for collectEdges and refactoring one method, getCycleGraph, in GraphOpsSuite.scala.

Thank you,
semih

Author: Semih Salihoglu <semihsalihoglu@gmail.com>

Closes #580 from semihsalihoglu/GraphPrimitives2 and squashes the following commits:

937d3ec [Semih Salihoglu] - Fixed the scalastyle errors.
a69a152 [Semih Salihoglu] - Adding collectEdges and pickRandomVertices. - Adding tests for collectEdges. - Refactoring a getCycle utility function for GraphOpsSuite.scala.
41265a6 [Semih Salihoglu] - Adding collectEdges and pickRandomVertex. - Adding tests for collectEdges. - Recycling a getCycle utility test file.
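For context, here is a minimal sketch of how the two new operators might be called, assuming a running SparkContext `sc`; the toy cycle graph below is illustrative only and not part of this patch:

import org.apache.spark.graphx._

// A 5-vertex cycle: 0 -> 1 -> 2 -> 3 -> 4 -> 0.
val rawEdges = sc.parallelize((0L until 5L).map(i => (i, (i + 1) % 5)))
val graph = Graph.fromEdgeTuples(rawEdges, 1.0)

// collectEdges returns, for every vertex that has at least one edge in the
// requested direction, the array of those edges.
val outEdges = graph.collectEdges(EdgeDirection.Out)
outEdges.collect.foreach { case (vid, es) =>
  println(s"$vid -> ${es.map(_.dstId).mkString(", ")}")
}

// pickRandomVertex samples and returns the id of one vertex of the graph.
val randomVid: VertexId = graph.pickRandomVertex()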
Diffstat (limited to 'graphx/src/test')
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala  |  134
1 file changed, 125 insertions(+), 9 deletions(-)
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
index bc2ad5677f..6386306c04 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
@@ -42,21 +42,20 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
test("collectNeighborIds") {
withSpark { sc =>
- val chain = (0 until 100).map(x => (x, (x+1)%100) )
- val rawEdges = sc.parallelize(chain, 3).map { case (s,d) => (s.toLong, d.toLong) }
- val graph = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+ val graph = getCycleGraph(sc, 100)
val nbrs = graph.collectNeighborIds(EdgeDirection.Either).cache()
- assert(nbrs.count === chain.size)
+ assert(nbrs.count === 100)
assert(graph.numVertices === nbrs.count)
nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2) }
- nbrs.collect.foreach { case (vid, nbrs) =>
- val s = nbrs.toSet
- assert(s.contains((vid + 1) % 100))
- assert(s.contains(if (vid > 0) vid - 1 else 99 ))
+ nbrs.collect.foreach {
+ case (vid, nbrs) =>
+ val s = nbrs.toSet
+ assert(s.contains((vid + 1) % 100))
+ assert(s.contains(if (vid > 0) vid - 1 else 99))
}
}
}
-
+
test ("filter") {
withSpark { sc =>
val n = 5
@@ -80,4 +79,121 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
}
}
+ test("collectEdgesCycleDirectionOut") {
+ withSpark { sc =>
+ val graph = getCycleGraph(sc, 100)
+ val edges = graph.collectEdges(EdgeDirection.Out).cache()
+ assert(edges.count == 100)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeDstIds = s.map(e => e.dstId)
+ assert(edgeDstIds.contains((vid + 1) % 100))
+ }
+ }
+ }
+
+ test("collectEdgesCycleDirectionIn") {
+ withSpark { sc =>
+ val graph = getCycleGraph(sc, 100)
+ val edges = graph.collectEdges(EdgeDirection.In).cache()
+ assert(edges.count == 100)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeSrcIds = s.map(e => e.srcId)
+ assert(edgeSrcIds.contains(if (vid > 0) vid - 1 else 99))
+ }
+ }
+ }
+
+ test("collectEdgesCycleDirectionEither") {
+ withSpark { sc =>
+ val graph = getCycleGraph(sc, 100)
+ val edges = graph.collectEdges(EdgeDirection.Either).cache()
+ assert(edges.count == 100)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 2) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
+ assert(edgeIds.contains((vid + 1) % 100))
+ assert(edgeIds.contains(if (vid > 0) vid - 1 else 99))
+ }
+ }
+ }
+
+ test("collectEdgesChainDirectionOut") {
+ withSpark { sc =>
+ val graph = getChainGraph(sc, 50)
+ val edges = graph.collectEdges(EdgeDirection.Out).cache()
+ assert(edges.count == 49)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeDstIds = s.map(e => e.dstId)
+ assert(edgeDstIds.contains(vid + 1))
+ }
+ }
+ }
+
+ test("collectEdgesChainDirectionIn") {
+ withSpark { sc =>
+ val graph = getChainGraph(sc, 50)
+ val edges = graph.collectEdges(EdgeDirection.In).cache()
+ // We expect only 49 because collectEdges does not return vertices that do
+ // not have any edges in the specified direction.
+ assert(edges.count == 49)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeSrcIds = s.map(e => e.srcId)
+ assert(edgeSrcIds.contains(vid - 1))
+ }
+ }
+ }
+
+ test("collectEdgesChainDirectionEither") {
+ withSpark { sc =>
+ val graph = getChainGraph(sc, 50)
+ val edges = graph.collectEdges(EdgeDirection.Either).cache()
+ // We expect all 50 vertices here because, in the Either direction, every
+ // vertex of the chain has at least one edge.
+ assert(edges.count === 50)
+ edges.collect.foreach {
+ case (vid, edges) =>
+ assert(edges.size == (if (vid > 0 && vid < 49) 2 else 1))
+ }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
+ if (vid == 0) { assert(edgeIds.contains(1)) }
+ else if (vid == 49) { assert(edgeIds.contains(48)) }
+ else {
+ assert(edgeIds.contains(vid + 1))
+ assert(edgeIds.contains(vid - 1))
+ }
+ }
+ }
+ }
+
+ private def getCycleGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
+ val cycle = (0 until numVertices).map(x => (x, (x + 1) % numVertices))
+ getGraphFromSeq(sc, cycle)
+ }
+
+ private def getChainGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
+ val chain = (0 until numVertices - 1).map(x => (x, x + 1))
+ getGraphFromSeq(sc, chain)
+ }
+
+ private def getGraphFromSeq(sc: SparkContext, seq: IndexedSeq[(Int, Int)]): Graph[Double, Int] = {
+ val rawEdges = sc.parallelize(seq, 3).map { case (s, d) => (s.toLong, d.toLong) }
+ Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+ }
}
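
As the comment in the chain tests notes, collectEdges omits vertices that have no edge in the requested direction. A short sketch of that behavior on a three-vertex chain, again assuming a SparkContext `sc` (not part of the patch):

val chainEdges = sc.parallelize(Seq((0L, 1L), (1L, 2L)))
val g = Graph.fromEdgeTuples(chainEdges, 1.0)
// Vertex 0 has no in-edge and vertex 2 has no out-edge, so each result
// covers only two of the three vertices.
assert(g.collectEdges(EdgeDirection.In).count == 2)
assert(g.collectEdges(EdgeDirection.Out).count == 2)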