aboutsummaryrefslogtreecommitdiff
path: root/graphx
diff options
context:
space:
mode:
authorAnkur Dave <ankurdave@gmail.com>2014-01-13 16:48:11 -0800
committerAnkur Dave <ankurdave@gmail.com>2014-01-13 17:03:03 -0800
commitae4b75d94a4a0f2545e6d90d6f9b8f162bf70ded (patch)
tree2ed8615d2239f7f349fab128cb3eaeec191f3abb /graphx
parent1bd5cefcae2769d48ad5ef4b8197193371c754da (diff)
downloadspark-ae4b75d94a4a0f2545e6d90d6f9b8f162bf70ded.tar.gz
spark-ae4b75d94a4a0f2545e6d90d6f9b8f162bf70ded.tar.bz2
spark-ae4b75d94a4a0f2545e6d90d6f9b8f162bf70ded.zip
Add EdgeDirection.Either and use it to fix CC bug
The bug was due to a misunderstanding of the activeSetOpt parameter to Graph.mapReduceTriplets. Passing EdgeDirection.Both causes mapReduceTriplets to run only on edges with *both* vertices in the active set. This commit adds EdgeDirection.Either, which causes mapReduceTriplets to run on edges with *either* vertex in the active set. This is what connected components needed.
Diffstat (limited to 'graphx')
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala8
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/Graph.scala9
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala33
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala7
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala4
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala41
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala5
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala3
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala2
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala2
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala2
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala2
12 files changed, 64 insertions, 54 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala
index 9d37f6513f..5b58a61bbd 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala
@@ -6,11 +6,12 @@ package org.apache.spark.graphx
class EdgeDirection private (private val name: String) extends Serializable {
/**
* Reverse the direction of an edge. An in becomes out,
- * out becomes in and both remains both.
+ * out becomes in and both and either remain the same.
*/
def reverse: EdgeDirection = this match {
case EdgeDirection.In => EdgeDirection.Out
case EdgeDirection.Out => EdgeDirection.In
+ case EdgeDirection.Either => EdgeDirection.Either
case EdgeDirection.Both => EdgeDirection.Both
}
@@ -32,6 +33,9 @@ object EdgeDirection {
/** Edges originating from a vertex. */
final val Out = new EdgeDirection("Out")
- /** All edges adjacent to a vertex. */
+ /** Edges originating from *or* arriving at a vertex of interest. */
+ final val Either = new EdgeDirection("Either")
+
+ /** Edges originating from *and* arriving at a vertex of interest. */
final val Both = new EdgeDirection("Both")
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index 7d4f0de3d6..49705fdf5d 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -274,9 +274,12 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] {
* of the map phase
*
* @param activeSetOpt optionally, a set of "active" vertices and a direction of edges to consider
- * when running `mapFunc`. For example, if the direction is Out, `mapFunc` will only be run on
- * edges originating from vertices in the active set. The active set must have the same index as
- * the graph's vertices.
+ * when running `mapFunc`. If the direction is `In`, `mapFunc` will only be run on edges with
+ * destination in the active set. If the direction is `Out`, `mapFunc` will only be run on edges
+ * originating from vertices in the active set. If the direction is `Either`, `mapFunc` will be
+ * run on edges with *either* vertex in the active set. If the direction is `Both`, `mapFunc` will
+ * be run on edges with *both* vertices in the active set. The active set must have the same index
+ * as the graph's vertices.
*
* @example We can use this function to compute the in-degree of each
* vertex
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index 578eb331c1..66d5180020 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -38,7 +38,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) {
* The degree of each vertex in the graph.
* @note Vertices with no edges are not returned in the resulting RDD.
*/
- lazy val degrees: VertexRDD[Int] = degreesRDD(EdgeDirection.Both)
+ lazy val degrees: VertexRDD[Int] = degreesRDD(EdgeDirection.Either)
/**
* Computes the neighboring vertex degrees.
@@ -50,7 +50,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) {
graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _)
} else if (edgeDirection == EdgeDirection.Out) {
graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _)
- } else { // EdgeDirection.both
+ } else { // EdgeDirection.Either
graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _)
}
}
@@ -65,7 +65,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) {
*/
def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]] = {
val nbrs =
- if (edgeDirection == EdgeDirection.Both) {
+ if (edgeDirection == EdgeDirection.Either) {
graph.mapReduceTriplets[Array[VertexID]](
mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))),
reduceFunc = _ ++ _
@@ -79,7 +79,8 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) {
mapFunc = et => Iterator((et.dstId, Array(et.srcId))),
reduceFunc = _ ++ _)
} else {
- throw new SparkException("It doesn't make sense to collect neighbor ids without a direction.")
+ throw new SparkException("It doesn't make sense to collect neighbor ids without a " +
+ "direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)")
}
graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
nbrsOpt.getOrElse(Array.empty[VertexID])
@@ -100,11 +101,19 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) {
*/
def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexID, VD)]] = {
val nbrs = graph.mapReduceTriplets[Array[(VertexID,VD)]](
- edge => Iterator(
- (edge.srcId, Array((edge.dstId, edge.dstAttr))),
- (edge.dstId, Array((edge.srcId, edge.srcAttr)))),
- (a, b) => a ++ b,
- edgeDirection)
+ edge => {
+ val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr)))
+ val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr)))
+ edgeDirection match {
+ case EdgeDirection.Either => Iterator(msgToSrc, msgToDst)
+ case EdgeDirection.In => Iterator(msgToDst)
+ case EdgeDirection.Out => Iterator(msgToSrc)
+ case EdgeDirection.Both =>
+ throw new SparkException("collectNeighbors does not support EdgeDirection.Both. Use" +
+ "EdgeDirection.Either instead.")
+ }
+ },
+ (a, b) => a ++ b)
graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
nbrsOpt.getOrElse(Array.empty[(VertexID, VD)])
@@ -237,7 +246,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) {
def pregel[A: ClassTag](
initialMsg: A,
maxIterations: Int = Int.MaxValue,
- activeDirection: EdgeDirection = EdgeDirection.Out)(
+ activeDirection: EdgeDirection = EdgeDirection.Either)(
vprog: (VertexID, VD, A) => VD,
sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexID,A)],
mergeMsg: (A, A) => A)
@@ -271,8 +280,8 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) {
*
* @see [[org.apache.spark.graphx.lib.ConnectedComponents]]
*/
- def connectedComponents(undirected: Boolean = true): Graph[VertexID, ED] = {
- ConnectedComponents.run(graph, undirected)
+ def connectedComponents(): Graph[VertexID, ED] = {
+ ConnectedComponents.run(graph)
}
/**
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 83e28d0ab2..75b44ddac9 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -67,7 +67,10 @@ object Pregel {
*
* @param activeDirection the direction of edges incident to a vertex that received a message in
* the previous round on which to run `sendMsg`. For example, if this is `EdgeDirection.Out`, only
- * out-edges of vertices that received a message in the previous round will run.
+ * out-edges of vertices that received a message in the previous round will run. The default is
+ * `EdgeDirection.Either`, which will run `sendMsg` on edges where either side received a message
+ * in the previous round. If this is `EdgeDirection.Both`, `sendMsg` will only run on edges where
+ * *both* vertices received a message.
*
* @param vprog the user-defined vertex program which runs on each
* vertex and receives the inbound message and computes a new vertex
@@ -90,7 +93,7 @@ object Pregel {
*/
def apply[VD: ClassTag, ED: ClassTag, A: ClassTag]
(graph: Graph[VD, ED], initialMsg: A, maxIterations: Int = Int.MaxValue,
- activeDirection: EdgeDirection = EdgeDirection.Out)(
+ activeDirection: EdgeDirection = EdgeDirection.Either)(
vprog: (VertexID, VD, A) => VD,
sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexID,A)],
mergeMsg: (A, A) => A)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 6a2abc71cc..c21f8935d9 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -275,6 +275,10 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
} else {
edgePartition.iterator.filter(e => vPart.isActive(e.srcId) && vPart.isActive(e.dstId))
}
+ case Some(EdgeDirection.Either) =>
+ // TODO: Because we only have a clustered index on the source vertex ID, we can't filter
+ // the index here. Instead we have to scan all edges and then do the filter.
+ edgePartition.iterator.filter(e => vPart.isActive(e.srcId) || vPart.isActive(e.dstId))
case Some(EdgeDirection.Out) =>
if (activeFraction < 0.8) {
edgePartition.indexIterator(srcVertexID => vPart.isActive(srcVertexID))
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
index d078d2acdb..d057c933d7 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
@@ -19,37 +19,22 @@ object ConnectedComponents {
* @return a graph with vertex attributes containing the smallest vertex in each
* connected component
*/
- def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], undirected: Boolean = true):
+ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]):
Graph[VertexID, ED] = {
val ccGraph = graph.mapVertices { case (vid, _) => vid }
- if (undirected) {
- def sendMessage(edge: EdgeTriplet[VertexID, ED]) = {
- if (edge.srcAttr < edge.dstAttr) {
- Iterator((edge.dstId, edge.srcAttr))
- } else if (edge.srcAttr > edge.dstAttr) {
- Iterator((edge.srcId, edge.dstAttr))
- } else {
- Iterator.empty
- }
+ def sendMessage(edge: EdgeTriplet[VertexID, ED]) = {
+ if (edge.srcAttr < edge.dstAttr) {
+ Iterator((edge.dstId, edge.srcAttr))
+ } else if (edge.srcAttr > edge.dstAttr) {
+ Iterator((edge.srcId, edge.dstAttr))
+ } else {
+ Iterator.empty
}
- val initialMessage = Long.MaxValue
- Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Both)(
- vprog = (id, attr, msg) => math.min(attr, msg),
- sendMsg = sendMessage,
- mergeMsg = (a, b) => math.min(a, b))
- } else {
- def sendMessage(edge: EdgeTriplet[VertexID, ED]) = {
- if (edge.srcAttr < edge.dstAttr) {
- Iterator((edge.dstId, edge.srcAttr))
- } else {
- Iterator.empty
- }
- }
- val initialMessage = Long.MaxValue
- Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Out)(
- vprog = (id, attr, msg) => math.min(attr, msg),
- sendMsg = sendMessage,
- mergeMsg = (a, b) => math.min(a, b))
}
+ val initialMessage = Long.MaxValue
+ Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)(
+ vprog = (id, attr, msg) => math.min(attr, msg),
+ sendMsg = sendMessage,
+ mergeMsg = (a, b) => math.min(a, b))
} // end of connectedComponents
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index cf95267e77..6ced2462eb 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -77,7 +77,7 @@ object PageRank extends Logging {
val initialMessage = 0.0
// Execute pregel for a fixed number of iterations.
- Pregel(pagerankGraph, initialMessage, numIter)(
+ Pregel(pagerankGraph, initialMessage, numIter, activeDirection = EdgeDirection.Out)(
vertexProgram, sendMessage, messageCombiner)
}
@@ -153,7 +153,8 @@ object PageRank extends Logging {
val initialMessage = resetProb / (1.0 - resetProb)
// Execute a dynamic version of Pregel.
- Pregel(pagerankGraph, initialMessage)(vertexProgram, sendMessage, messageCombiner)
+ Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)(
+ vertexProgram, sendMessage, messageCombiner)
.mapVertices((vid, attr) => attr._1)
} // end of deltaPageRank
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
index 43c4b9cf2d..edffbcc5ac 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
@@ -53,7 +53,8 @@ object StronglyConnectedComponents {
// collect min of all my neighbor's scc values, update if it's smaller than mine
// then notify any neighbors with scc values larger than mine
- sccWorkGraph = Pregel[(VertexID, Boolean), ED, VertexID](sccWorkGraph, Long.MaxValue)(
+ sccWorkGraph = Pregel[(VertexID, Boolean), ED, VertexID](
+ sccWorkGraph, Long.MaxValue, activeDirection = EdgeDirection.Out)(
(vid, myScc, neighborScc) => (math.min(myScc._1, neighborScc), myScc._2),
e => {
if (e.srcId < e.dstId) {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
index 58da9e3aed..d3e22b176c 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
@@ -28,7 +28,7 @@ object TriangleCount {
// Construct set representations of the neighborhoods
val nbrSets: VertexRDD[VertexSet] =
- g.collectNeighborIds(EdgeDirection.Both).mapValues { (vid, nbrs) =>
+ g.collectNeighborIds(EdgeDirection.Either).mapValues { (vid, nbrs) =>
val set = new VertexSet(4)
var i = 0
while (i < nbrs.size) {
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
index 7a901409d5..280f50e39a 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
@@ -28,7 +28,7 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
val chain = (0 until 100).map(x => (x, (x+1)%100) )
val rawEdges = sc.parallelize(chain, 3).map { case (s,d) => (s.toLong, d.toLong) }
val graph = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
- val nbrs = graph.collectNeighborIds(EdgeDirection.Both).cache()
+ val nbrs = graph.collectNeighborIds(EdgeDirection.Either).cache()
assert(nbrs.count === chain.size)
assert(graph.numVertices === nbrs.count)
nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2) }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
index 1ff3d75633..bceff11b8e 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
@@ -32,7 +32,7 @@ class PregelSuite extends FunSuite with LocalSparkContext {
Set((1: VertexID, 1)) ++ (2 to n).map(x => (x: VertexID, 0)).toSet)
val result = Pregel(chainWithSeed, 0)(
(vid, attr, msg) => math.max(msg, attr),
- et => Iterator((et.dstId, et.srcAttr)),
+ et => if (et.dstAttr != et.srcAttr) Iterator((et.dstId, et.srcAttr)) else Iterator.empty,
(a: Int, b: Int) => math.max(a, b))
assert(result.vertices.collect.toSet ===
chain.vertices.mapValues { (vid, attr) => attr + 1 }.collect.toSet)
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
index 86da8f1b46..27c8705bca 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
@@ -102,7 +102,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
val defaultUser = ("John Doe", "Missing")
// Build the initial Graph
val graph = Graph(users, relationships, defaultUser)
- val ccGraph = graph.connectedComponents(undirected = true)
+ val ccGraph = graph.connectedComponents()
val vertices = ccGraph.vertices.collect
for ( (id, cc) <- vertices ) {
assert(cc == 0)