Diffstat (limited to 'graphx')
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala      |  59
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala | 134
2 files changed, 183 insertions(+), 10 deletions(-)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index 0fc1e4df68..377d9d6bd5 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -18,11 +18,11 @@
package org.apache.spark.graphx
import scala.reflect.ClassTag
-
import org.apache.spark.SparkContext._
import org.apache.spark.SparkException
import org.apache.spark.graphx.lib._
import org.apache.spark.rdd.RDD
+import scala.util.Random
/**
* Contains additional functionality for [[Graph]]. All operations are expressed in terms of the
@@ -138,6 +138,42 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
} // end of collectNeighbor
/**
+ * Returns an RDD that contains for each vertex v its local edges,
+ * i.e., the edges that are incident on v, in the user-specified direction.
+ * Warning: vertices with no edges in the given direction (including
+ * singleton vertices) will not be part of the return value.
+ *
+ * @note This function could be highly inefficient on power-law
+ * graphs where high degree vertices may force a large amount of
+ * information to be collected to a single location.
+ *
+ * @param edgeDirection the direction along which to collect
+ * the local edges of vertices
+ *
+ * @return the local edges for each vertex
+ */
+ def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = {
+ edgeDirection match {
+ case EdgeDirection.Either =>
+ graph.mapReduceTriplets[Array[Edge[ED]]](
+ edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))),
+ (edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
+ (a, b) => a ++ b)
+ case EdgeDirection.In =>
+ graph.mapReduceTriplets[Array[Edge[ED]]](
+ edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
+ (a, b) => a ++ b)
+ case EdgeDirection.Out =>
+ graph.mapReduceTriplets[Array[Edge[ED]]](
+ edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
+ (a, b) => a ++ b)
+ case EdgeDirection.Both =>
+ throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
+ "EdgeDirection.Either instead.")
+ }
+ }
+
+ /**
* Join the vertices with an RDD and then apply a function from the
* vertex and RDD entry to a new vertex value. The input table
* should contain at most one entry for each vertex. If no entry is
@@ -210,6 +246,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
}
/**
+ * Picks a random vertex from the graph and returns its ID.
+ */
+ def pickRandomVertex(): VertexId = {
+ val probability = 50.0 / graph.numVertices // 50.0, not 50: integer division would be 0 for graphs with more than 50 vertices
+ var found = false
+ var retVal: VertexId = null.asInstanceOf[VertexId]
+ while (!found) {
+ val selectedVertices = graph.vertices.flatMap { vidVvals =>
+ if (Random.nextDouble() < probability) { Some(vidVvals._1) }
+ else { None }
+ }
+ if (selectedVertices.count > 1) {
+ found = true
+ val collectedVertices = selectedVertices.collect()
+ retVal = collectedVertices(Random.nextInt(collectedVertices.size))
+ }
+ }
+ retVal
+ }
+
+ /**
* Execute a Pregel-like iterative vertex-parallel abstraction. The
* user-defined vertex-program `vprog` is executed in parallel on
* each vertex receiving any inbound messages and computing a new
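To make the new operator concrete, here is a minimal, self-contained sketch of collectEdges on a small cycle graph, mirroring the tests below. The local-mode SparkContext setup and the demo object name are assumptions for illustration, not part of this patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._

// Hypothetical demo object, not part of this patch.
object CollectEdgesDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("collectEdges-demo"))
    // Build the 4-vertex cycle 0 -> 1 -> 2 -> 3 -> 0 with default vertex value 1.0.
    val rawEdges = sc.parallelize((0 until 4).map(x => (x.toLong, ((x + 1) % 4).toLong)))
    val graph = Graph.fromEdgeTuples(rawEdges, 1.0)
    // On a cycle every vertex has exactly one outgoing and one incoming edge, so
    // Out and In each yield a one-element array and Either yields two elements.
    graph.collectEdges(EdgeDirection.Either).collect().foreach { case (vid, edges) =>
      println(s"vertex $vid: " + edges.map(e => s"${e.srcId}->${e.dstId}").mkString(", "))
    }
    sc.stop()
  }
}
```

Like collectNeighborIds, collectEdges is reached through the implicit conversion from Graph to GraphOps, so callers invoke it directly on the graph.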
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
index bc2ad5677f..6386306c04 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
@@ -42,21 +42,20 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
test("collectNeighborIds") {
withSpark { sc =>
- val chain = (0 until 100).map(x => (x, (x+1)%100) )
- val rawEdges = sc.parallelize(chain, 3).map { case (s,d) => (s.toLong, d.toLong) }
- val graph = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+ val graph = getCycleGraph(sc, 100)
val nbrs = graph.collectNeighborIds(EdgeDirection.Either).cache()
- assert(nbrs.count === chain.size)
+ assert(nbrs.count === 100)
assert(graph.numVertices === nbrs.count)
nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2) }
- nbrs.collect.foreach { case (vid, nbrs) =>
- val s = nbrs.toSet
- assert(s.contains((vid + 1) % 100))
- assert(s.contains(if (vid > 0) vid - 1 else 99 ))
+ nbrs.collect.foreach {
+ case (vid, nbrs) =>
+ val s = nbrs.toSet
+ assert(s.contains((vid + 1) % 100))
+ assert(s.contains(if (vid > 0) vid - 1 else 99))
}
}
}
-
+
test ("filter") {
withSpark { sc =>
val n = 5
@@ -80,4 +79,121 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
}
}
+ test("collectEdgesCycleDirectionOut") {
+ withSpark { sc =>
+ val graph = getCycleGraph(sc, 100)
+ val edges = graph.collectEdges(EdgeDirection.Out).cache()
+ assert(edges.count == 100)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeDstIds = s.map(e => e.dstId)
+ assert(edgeDstIds.contains((vid + 1) % 100))
+ }
+ }
+ }
+
+ test("collectEdgesCycleDirectionIn") {
+ withSpark { sc =>
+ val graph = getCycleGraph(sc, 100)
+ val edges = graph.collectEdges(EdgeDirection.In).cache()
+ assert(edges.count == 100)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeSrcIds = s.map(e => e.srcId)
+ assert(edgeSrcIds.contains(if (vid > 0) vid - 1 else 99))
+ }
+ }
+ }
+
+ test("collectEdgesCycleDirectionEither") {
+ withSpark { sc =>
+ val graph = getCycleGraph(sc, 100)
+ val edges = graph.collectEdges(EdgeDirection.Either).cache()
+ assert(edges.count == 100)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 2) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
+ assert(edgeIds.contains((vid + 1) % 100))
+ assert(edgeIds.contains(if (vid > 0) vid - 1 else 99))
+ }
+ }
+ }
+
+ test("collectEdgesChainDirectionOut") {
+ withSpark { sc =>
+ val graph = getChainGraph(sc, 50)
+ val edges = graph.collectEdges(EdgeDirection.Out).cache()
+ assert(edges.count == 49)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeDstIds = s.map(e => e.dstId)
+ assert(edgeDstIds.contains(vid + 1))
+ }
+ }
+ }
+
+ test("collectEdgesChainDirectionIn") {
+ withSpark { sc =>
+ val graph = getChainGraph(sc, 50)
+ val edges = graph.collectEdges(EdgeDirection.In).cache()
+ // We expect only 49 because collectEdges does not return vertices that do
+ // not have any edges in the specified direction.
+ assert(edges.count == 49)
+ edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeDstIds = s.map(e => e.srcId)
+ assert(edgeDstIds.contains((vid - 1) % 100))
+ }
+ }
+ }
+
+ test("collectEdgesChainDirectionEither") {
+ withSpark { sc =>
+ val graph = getChainGraph(sc, 50)
+ val edges = graph.collectEdges(EdgeDirection.Either).cache()
+ // Unlike the In and Out cases, all 50 vertices have at least one edge in
+ // either direction, so none are dropped from the result.
+ assert(edges.count === 50)
+ edges.collect.foreach {
+ case (vid, edges) => if (vid > 0 && vid < 49) assert(edges.size == 2)
+ else assert(edges.size == 1)
+ }
+ edges.collect.foreach {
+ case (vid, edges) =>
+ val s = edges.toSet
+ val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
+ if (vid == 0) { assert(edgeIds.contains(1)) }
+ else if (vid == 49) { assert(edgeIds.contains(48)) }
+ else {
+ assert(edgeIds.contains(vid + 1))
+ assert(edgeIds.contains(vid - 1))
+ }
+ }
+ }
+ }
+
+ private def getCycleGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
+ val cycle = (0 until numVertices).map(x => (x, (x + 1) % numVertices))
+ getGraphFromSeq(sc, cycle)
+ }
+
+ private def getChainGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
+ val chain = (0 until numVertices - 1).map(x => (x, x + 1))
+ getGraphFromSeq(sc, chain)
+ }
+
+ private def getGraphFromSeq(sc: SparkContext, seq: IndexedSeq[(Int, Int)]): Graph[Double, Int] = {
+ val rawEdges = sc.parallelize(seq, 3).map { case (s, d) => (s.toLong, d.toLong) }
+ Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+ }
}
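A note on pickRandomVertex's sampling loop: each vertex is kept independently with probability 50.0 / numVertices, so the expected sample size is roughly 50 for any graph with at least 50 vertices, and the loop almost always succeeds on its first pass. A hypothetical companion test in the style of the suite above (not included in this patch) could look like:

```scala
// Hypothetical test, not part of this patch: pickRandomVertex must return
// an ID that actually belongs to the graph.
test("pickRandomVertex") {
  withSpark { sc =>
    val graph = getCycleGraph(sc, 100)
    val vid = graph.pickRandomVertex()
    assert(graph.vertices.filter { case (id, _) => id == vid }.count === 1)
  }
}
```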