author     Semih Salihoglu <semihsalihoglu@gmail.com>  2014-02-24 22:42:30 -0800
committer  Reynold Xin <rxin@apache.org>               2014-02-24 22:42:30 -0800
commit     1f4c7f7ecc9d2393663fc4d059e71fe4c70bad84 (patch)
tree       adb39bfef3d08da56b8e33c882070d911ebaf1bd /graphx
parent     a4f4fbc8fa5886a8c6ee58ee614de0cc6e67dcd7 (diff)
Graph primitives2
Hi guys,

I'm following Joey and Ankur's suggestions to add collectEdges and
pickRandomVertex. I'm also adding the tests for collectEdges and refactoring
one method getCycleGraph in GraphOpsSuite.scala.

Thank you,
semih

Author: Semih Salihoglu <semihsalihoglu@gmail.com>

Closes #580 from semihsalihoglu/GraphPrimitives2 and squashes the following commits:

937d3ec [Semih Salihoglu] - Fixed the scalastyle errors.
a69a152 [Semih Salihoglu] - Adding collectEdges and pickRandomVertices. - Adding tests for collectEdges. - Refactoring a getCycle utility function for GraphOpsSuite.scala.
41265a6 [Semih Salihoglu] - Adding collectEdges and pickRandomVertex. - Adding tests for collectEdges. - Recycling a getCycle utility test file.
Diffstat (limited to 'graphx')
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala       |  59
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala  | 134
2 files changed, 183 insertions(+), 10 deletions(-)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index 0fc1e4df68..377d9d6bd5 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -18,11 +18,11 @@
package org.apache.spark.graphx
import scala.reflect.ClassTag
-
import org.apache.spark.SparkContext._
import org.apache.spark.SparkException
import org.apache.spark.graphx.lib._
import org.apache.spark.rdd.RDD
+import scala.util.Random
/**
* Contains additional functionality for [[Graph]]. All operations are expressed in terms of the
@@ -138,6 +138,42 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
} // end of collectNeighbor
/**
+ * Returns an RDD that contains for each vertex v its local edges,
+ * i.e., the edges that are incident on v, in the user-specified direction.
+ * Warning: singleton vertices, i.e., those with no edges in the given
+ * direction, will not be part of the return value.
+ *
+ * @note This function could be highly inefficient on power-law
+ * graphs where high degree vertices may force a large amount of
+ * information to be collected to a single location.
+ *
+ * @param edgeDirection the direction along which to collect
+ * the local edges of vertices
+ *
+ * @return the local edges for each vertex
+ */
+ def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = {
+ edgeDirection match {
+ case EdgeDirection.Either =>
+ graph.mapReduceTriplets[Array[Edge[ED]]](
+ edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))),
+ (edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
+ (a, b) => a ++ b)
+ case EdgeDirection.In =>
+ graph.mapReduceTriplets[Array[Edge[ED]]](
+ edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
+ (a, b) => a ++ b)
+ case EdgeDirection.Out =>
+ graph.mapReduceTriplets[Array[Edge[ED]]](
+ edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
+ (a, b) => a ++ b)
+ case EdgeDirection.Both =>
+ throw new SparkException("collectEdges does not support EdgeDirection.Both. Use " +
+ "EdgeDirection.Either instead.")
+ }
+ }
+
+ /**
* Join the vertices with an RDD and then apply a function from the
* vertex and RDD entry to a new vertex value. The input table
* should contain at most one entry for each vertex. If no entry is
@@ -210,6 +246,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
}
/**
+ * Picks a random vertex from the graph and returns its ID.
+ */
+ def pickRandomVertex(): VertexId = {
+ val probability = 50.0 / graph.numVertices
+ var found = false
+ var retVal: VertexId = null.asInstanceOf[VertexId]
+ while (!found) {
+ val selectedVertices = graph.vertices.flatMap { vidVvals =>
+ if (Random.nextDouble() < probability) { Some(vidVvals._1) }
+ else { None }
+ }
+ if (selectedVertices.count > 0) {
+ found = true
+ val collectedVertices = selectedVertices.collect()
+ retVal = collectedVertices(Random.nextInt(collectedVertices.size))
+ }
+ }
+ retVal
+ }
+
+ /**
* Execute a Pregel-like iterative vertex-parallel abstraction. The
* user-defined vertex-program `vprog` is executed in parallel on
* each vertex receiving any inbound messages and computing a new
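Example: a minimal, self-contained sketch of the new collectEdges operator. The SparkContext setup and the 3-cycle graph below are illustrative only, not part of this commit:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.graphx._

    object CollectEdgesExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local").setAppName("collectEdgesExample"))
        // A directed 3-cycle: 0 -> 1 -> 2 -> 0. fromEdgeTuples assigns every
        // edge the attribute 1, so the edge type ED is Int.
        val rawEdges = sc.parallelize(Seq((0L, 1L), (1L, 2L), (2L, 0L)))
        val graph = Graph.fromEdgeTuples(rawEdges, defaultValue = 1.0)
        // For each vertex, collect the edges leaving it.
        val outEdges: VertexRDD[Array[Edge[Int]]] = graph.collectEdges(EdgeDirection.Out)
        outEdges.collect().foreach { case (vid, edges) =>
          println(s"vertex $vid: " + edges.map(e => s"${e.srcId}->${e.dstId}").mkString(", "))
        }
        sc.stop()
      }
    }

On this cycle every vertex has exactly one outgoing edge; with EdgeDirection.In on a chain graph, vertices with no incoming edges are absent from the returned VertexRDD, which the new tests below verify.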
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
index bc2ad5677f..6386306c04 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
@@ -42,21 +42,20 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
test("collectNeighborIds") {
withSpark { sc =>
- val chain = (0 until 100).map(x => (x, (x+1)%100) )
- val rawEdges = sc.parallelize(chain, 3).map { case (s,d) => (s.toLong, d.toLong) }
- val graph = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+ val graph = getCycleGraph(sc, 100)
val nbrs = graph.collectNeighborIds(EdgeDirection.Either).cache()
- assert(nbrs.count === chain.size)
+ assert(nbrs.count === 100)
assert(graph.numVertices === nbrs.count)
nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2) }
- nbrs.collect.foreach { case (vid, nbrs) =>
- val s = nbrs.toSet
- assert(s.contains((vid + 1) % 100))
- assert(s.contains(if (vid > 0) vid - 1 else 99 ))
+ nbrs.collect.foreach {
+ case (vid, nbrs) =>
+ val s = nbrs.toSet
+ assert(s.contains((vid + 1) % 100))
+ assert(s.contains(if (vid > 0) vid - 1 else 99))
}
}
}
-
+
test ("filter") {
withSpark { sc =>
val n = 5
@@ -80,4 +79,121 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
}
}
+ test("collectEdgesCycleDirectionOut") {
+ withSpark { sc =>
+ val graph = getCycleGraph(sc, 100)
+ val edges = graph.collectEdges(EdgeDirection.Out).cache()
+ assert(edges.count == 100)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeDstIds = s.map(e => e.dstId)
+ assert(edgeDstIds.contains((vid + 1) % 100))
+ }
+ }
+ }
+
+ test("collectEdgesCycleDirectionIn") {
+ withSpark { sc =>
+ val graph = getCycleGraph(sc, 100)
+ val edges = graph.collectEdges(EdgeDirection.In).cache()
+ assert(edges.count == 100)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeSrcIds = s.map(e => e.srcId)
+ assert(edgeSrcIds.contains(if (vid > 0) vid - 1 else 99))
+ }
+ }
+ }
+
+ test("collectEdgesCycleDirectionEither") {
+ withSpark { sc =>
+ val graph = getCycleGraph(sc, 100)
+ val edges = graph.collectEdges(EdgeDirection.Either).cache()
+ assert(edges.count == 100)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 2) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
+ assert(edgeIds.contains((vid + 1) % 100))
+ assert(edgeIds.contains(if (vid > 0) vid - 1 else 99))
+ }
+ }
+ }
+
+ test("collectEdgesChainDirectionOut") {
+ withSpark { sc =>
+ val graph = getChainGraph(sc, 50)
+ val edges = graph.collectEdges(EdgeDirection.Out).cache()
+ assert(edges.count == 49)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeDstIds = s.map(e => e.dstId)
+ assert(edgeDstIds.contains(vid + 1))
+ }
+ }
+ }
+
+ test("collectEdgesChainDirectionIn") {
+ withSpark { sc =>
+ val graph = getChainGraph(sc, 50)
+ val edges = graph.collectEdges(EdgeDirection.In).cache()
+ // We expect only 49 because collectEdges does not return vertices that do
+ // not have any edges in the specified direction.
+ assert(edges.count == 49)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeSrcIds = s.map(e => e.srcId)
+ assert(edgeSrcIds.contains(vid - 1))
+ }
+ }
+ }
+
+ test("collectEdgesChainDirectionEither") {
+ withSpark { sc =>
+ val graph = getChainGraph(sc, 50)
+ val edges = graph.collectEdges(EdgeDirection.Either).cache()
+ // We expect all 50 vertices here: on a chain, every vertex has at least
+ // one edge in either direction, so collectEdges drops none of them.
+ assert(edges.count === 50)
+ edges.collect.foreach {
+ case (vid, edges) => if (vid > 0 && vid < 49) assert(edges.size == 2)
+ else assert(edges.size == 1)
+ }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
+ if (vid == 0) { assert(edgeIds.contains(1)) }
+ else if (vid == 49) { assert(edgeIds.contains(48)) }
+ else {
+ assert(edgeIds.contains(vid + 1))
+ assert(edgeIds.contains(vid - 1))
+ }
+ }
+ }
+ }
+
+ private def getCycleGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
+ val cycle = (0 until numVertices).map(x => (x, (x + 1) % numVertices))
+ getGraphFromSeq(sc, cycle)
+ }
+
+ private def getChainGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
+ val chain = (0 until numVertices - 1).map(x => (x, (x + 1)))
+ getGraphFromSeq(sc, chain)
+ }
+
+ private def getGraphFromSeq(sc: SparkContext, seq: IndexedSeq[(Int, Int)]): Graph[Double, Int] = {
+ val rawEdges = sc.parallelize(seq, 3).map { case (s, d) => (s.toLong, d.toLong) }
+ Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+ }
}
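
Example: a sketch of pickRandomVertex usage, assuming a SparkContext named sc is in scope (as inside the suite's withSpark blocks). The method samples each vertex independently with probability roughly 50 / numVertices and retries until the sample is non-empty, so it may make several passes over the vertex RDD on unlucky draws:

    // Illustrative only: pick a random vertex id from a 100-vertex cycle.
    val rawEdges = sc.parallelize((0 until 100).map(x => (x.toLong, ((x + 1) % 100).toLong)))
    val graph = Graph.fromEdgeTuples(rawEdges, 1.0)
    val vid: VertexId = graph.pickRandomVertex()
    println(s"picked vertex $vid")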