aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala26
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala15
2 files changed, 41 insertions, 0 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index 116d1ea700..dc8b4789c4 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -279,6 +279,32 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
}
/**
+ * Convert bi-directional edges into uni-directional ones.
+ * Some graph algorithms (e.g., TriangleCount) assume that an input graph
+ * has its edges in canonical direction.
+ * This function rewrites the vertex ids of edges so that srcIds are smaller
+ * than dstIds, and merges the duplicated edges.
+ *
+ * @param mergeFunc the user defined reduce function which should
+ * be commutative and associative and is used to merge the attributes
+ * of duplicated edges
+ *
+ * @return the resulting graph with canonical edges
+ */
+ def convertToCanonicalEdges(
+ mergeFunc: (ED, ED) => ED = (e1, e2) => e1): Graph[VD, ED] = {
+ val newEdges =
+ graph.edges
+ .map {
+ case e if e.srcId < e.dstId => ((e.srcId, e.dstId), e.attr)
+ case e => ((e.dstId, e.srcId), e.attr)
+ }
+ .reduceByKey(mergeFunc)
+ .map(e => new Edge(e._1._1, e._1._2, e._2))
+ Graph(graph.vertices, newEdges)
+ }
+
+ /**
* Execute a Pregel-like iterative vertex-parallel abstraction. The
* user-defined vertex-program `vprog` is executed in parallel on
* each vertex receiving any inbound messages and computing a new
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
index ea94d4accb..9bc8007ce4 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
@@ -79,6 +79,21 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
}
}
+ test ("convertToCanonicalEdges") {
+ withSpark { sc =>
+ val vertices =
+ sc.parallelize(Seq[(VertexId, String)]((1, "one"), (2, "two"), (3, "three")), 2)
+ val edges =
+ sc.parallelize(Seq(Edge(1, 2, 1), Edge(2, 1, 1), Edge(3, 2, 2)))
+ val g: Graph[String, Int] = Graph(vertices, edges)
+
+ val g1 = g.convertToCanonicalEdges()
+
+ val e = g1.edges.collect().toSet
+ assert(e === Set(Edge(1, 2, 1), Edge(2, 3, 2)))
+ }
+ }
+
test("collectEdgesCycleDirectionOut") {
withSpark { sc =>
val graph = getCycleGraph(sc, 100)