diff options
-rw-r--r-- | graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala | 26 | ||||
-rw-r--r-- | graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala | 15 |
2 files changed, 41 insertions, 0 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 116d1ea700..dc8b4789c4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -279,6 +279,32 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali } /** + * Convert bi-directional edges into uni-directional ones. + * Some graph algorithms (e.g., TriangleCount) assume that an input graph + * has its edges in canonical direction. + * This function rewrites the vertex ids of edges so that srcIds are bigger + * than dstIds, and merges the duplicated edges. + * + * @param mergeFunc the user defined reduce function which should + * be commutative and associative and is used to combine the output + * of the map phase + * + * @return the resulting graph with canonical edges + */ + def convertToCanonicalEdges( + mergeFunc: (ED, ED) => ED = (e1, e2) => e1): Graph[VD, ED] = { + val newEdges = + graph.edges + .map { + case e if e.srcId < e.dstId => ((e.srcId, e.dstId), e.attr) + case e => ((e.dstId, e.srcId), e.attr) + } + .reduceByKey(mergeFunc) + .map(e => new Edge(e._1._1, e._1._2, e._2)) + Graph(graph.vertices, newEdges) + } + + /** * Execute a Pregel-like iterative vertex-parallel abstraction. The * user-defined vertex-program `vprog` is executed in parallel on * each vertex receiving any inbound messages and computing a new diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index ea94d4accb..9bc8007ce4 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -79,6 +79,21 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { } } + test ("convertToCanonicalEdges") { + withSpark { sc => + val vertices = + sc.parallelize(Seq[(VertexId, String)]((1, "one"), (2, "two"), (3, "three")), 2) + val edges = + sc.parallelize(Seq(Edge(1, 2, 1), Edge(2, 1, 1), Edge(3, 2, 2))) + val g: Graph[String, Int] = Graph(vertices, edges) + + val g1 = g.convertToCanonicalEdges() + + val e = g1.edges.collect().toSet + assert(e === Set(Edge(1, 2, 1), Edge(2, 3, 2))) + } + } + test("collectEdgesCycleDirectionOut") { withSpark { sc => val graph = getCycleGraph(sc, 100) |