 graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala      | 12 ++++++++++++
 graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala | 13 +++++++++----
 graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala     | 10 ++++++++++
 3 files changed, 31 insertions(+), 4 deletions(-)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
index 8c62897037..8b910fbc5a 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
@@ -300,6 +300,18 @@ class VertexRDD[@specialized VD: ClassTag](
   def reverseRoutingTables(): VertexRDD[VD] =
     this.mapVertexPartitions(vPart => vPart.withRoutingTable(vPart.routingTable.reverse))
 
+  /** Prepares this VertexRDD for efficient joins with the given EdgeRDD. */
+  def withEdges(edges: EdgeRDD[_, _]): VertexRDD[VD] = {
+    val routingTables = VertexRDD.createRoutingTables(edges, this.partitioner.get)
+    val vertexPartitions = partitionsRDD.zipPartitions(routingTables, true) {
+      (partIter, routingTableIter) =>
+        val routingTable =
+          if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty
+        partIter.map(_.withRoutingTable(routingTable))
+    }
+    new VertexRDD(vertexPartitions)
+  }
+
   /** Generates an RDD of vertex attributes suitable for shipping to the edge partitions. */
   private[graphx] def shipVertexAttributes(
       shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] = {
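
The new `withEdges` rebuilds this VertexRDD's routing tables against the given EdgeRDD's partitioning, so that later vertex-to-edge joins ship attributes to the correct edge partitions. A minimal usage sketch, assuming a live SparkContext `sc` (the sample vertices and edges are illustrative, not from the patch):

    import org.apache.spark.graphx._

    // Build a tiny graph, then move its edges with a partition strategy.
    val graph = Graph(
      sc.parallelize(Seq((0L, "a"), (1L, "b"))),
      sc.parallelize(Seq(Edge(0L, 1L, 1))))
    // With this patch, partitionBy internally calls vertices.withEdges(newEdges),
    // so the routing tables track the edges' new locations.
    val repartitioned = graph.partitionBy(PartitionStrategy.EdgePartition2D)
    repartitioned.triplets.count()  // the triplets view that the new test exercises
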
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 2f2d0e03fd..1649b244d2 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -88,8 +88,8 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
         }
         val edgePartition = builder.toEdgePartition
         Iterator((pid, edgePartition))
-      }, preservesPartitioning = true))
-    GraphImpl.fromExistingRDDs(vertices, newEdges)
+      }, preservesPartitioning = true)).cache()
+    GraphImpl.fromExistingRDDs(vertices.withEdges(newEdges), newEdges)
   }
 
   override def reverse: Graph[VD, ED] = {
@@ -277,7 +277,11 @@ object GraphImpl {
     GraphImpl(vertexRDD, edgeRDD)
   }
 
-  /** Create a graph from a VertexRDD and an EdgeRDD with arbitrary replicated vertices. */
+  /**
+   * Create a graph from a VertexRDD and an EdgeRDD with arbitrary replicated vertices. The
+   * VertexRDD must already be set up for efficient joins with the EdgeRDD by calling
+   * `VertexRDD.withEdges` or an appropriate VertexRDD constructor.
+   */
   def apply[VD: ClassTag, ED: ClassTag](
       vertices: VertexRDD[VD],
       edges: EdgeRDD[ED, _]): GraphImpl[VD, ED] = {
@@ -290,7 +294,8 @@
 
   /**
    * Create a graph from a VertexRDD and an EdgeRDD with the same replicated vertex type as the
-   * vertices.
+   * vertices. The VertexRDD must already be set up for efficient joins with the EdgeRDD by calling
+   * `VertexRDD.withEdges` or an appropriate VertexRDD constructor.
    */
   def fromExistingRDDs[VD: ClassTag, ED: ClassTag](
       vertices: VertexRDD[VD],
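
Both updated doc comments state the same precondition: callers of `GraphImpl.apply` and `GraphImpl.fromExistingRDDs` must hand in a VertexRDD whose routing tables are already aligned with the EdgeRDD. A hedged sketch of honoring that contract, again assuming a live SparkContext `sc` (GraphImpl is internal API, and the sample data is illustrative):

    import org.apache.spark.graphx._
    import org.apache.spark.graphx.impl.GraphImpl

    val base = Graph(
      sc.parallelize(Seq((0L, "a"), (1L, "b"))),
      sc.parallelize(Seq(Edge(0L, 1L, 1))))
    // Realign the vertices' routing tables with the edges before constructing
    // the GraphImpl, as the doc comments now require.
    val aligned = base.vertices.withEdges(base.edges)
    val g = GraphImpl.fromExistingRDDs(aligned, base.edges)
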
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 7b9bac5d9c..abc25d0671 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -133,6 +133,16 @@ class GraphSuite extends FunSuite with LocalSparkContext {
         Iterator((part.srcIds ++ part.dstIds).toSet)
       }.collect
       assert(verts.exists(id => partitionSetsUnpartitioned.count(_.contains(id)) > bound))
+
+      // Forming triplets view
+      val g = Graph(
+        sc.parallelize(List((0L, "a"), (1L, "b"), (2L, "c"))),
+        sc.parallelize(List(Edge(0L, 1L, 1), Edge(0L, 2L, 1)), 2))
+      assert(g.triplets.collect.map(_.toTuple).toSet ===
+        Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), (2L, "c"), 1)))
+      val gPart = g.partitionBy(EdgePartition2D)
+      assert(gPart.triplets.collect.map(_.toTuple).toSet ===
+        Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), (2L, "c"), 1)))
     }
   }
 