Diffstat (limited to 'graphx')
 graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala                     |  4 ++--
 graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala | 28 +++++++++++++++++++++++-----
 graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala                | 11 +++++++++++
 3 files changed, 36 insertions(+), 7 deletions(-)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
index 04fbc9dbab..2c8b245955 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
@@ -392,7 +392,7 @@ object VertexRDD {
    */
   def apply[VD: ClassTag](
       vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD): VertexRDD[VD] = {
-    VertexRDD(vertices, edges, defaultVal, (a, b) => b)
+    VertexRDD(vertices, edges, defaultVal, (a, b) => a)
   }
 
   /**
@@ -419,7 +419,7 @@ object VertexRDD {
       (vertexIter, routingTableIter) =>
         val routingTable =
           if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty
-        Iterator(ShippableVertexPartition(vertexIter, routingTable, defaultVal))
+        Iterator(ShippableVertexPartition(vertexIter, routingTable, defaultVal, mergeFunc))
     }
     new VertexRDD(vertexPartitions)
   }
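
The three-argument `VertexRDD.apply` now delegates with `(a, b) => a`, matching the keep-first default used by `ShippableVertexPartition` below. A minimal sketch of the resulting equivalence, assuming a running SparkContext `sc` and hypothetical sample data:

    // Sketch only: assumes GraphX on the classpath and a SparkContext `sc`.
    import org.apache.spark.graphx._

    val verts = sc.parallelize(List((1L, "a"), (1L, "b")))
    val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]]))

    // These two calls now behave identically; before this patch the first
    // delegated with (a, b) => b, keeping the later duplicate instead.
    val v1 = VertexRDD(verts, edges, "default")
    val v2 = VertexRDD(verts, edges, "default", (a: String, b: String) => a)
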
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala
index dca54b8a7d..5412d72047 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala
@@ -36,7 +36,7 @@ private[graphx]
 object ShippableVertexPartition {
   /** Construct a `ShippableVertexPartition` from the given vertices without any routing table. */
   def apply[VD: ClassTag](iter: Iterator[(VertexId, VD)]): ShippableVertexPartition[VD] =
-    apply(iter, RoutingTablePartition.empty, null.asInstanceOf[VD])
+    apply(iter, RoutingTablePartition.empty, null.asInstanceOf[VD], (a, b) => a)
 
   /**
    * Construct a `ShippableVertexPartition` from the given vertices with the specified routing
@@ -44,10 +44,28 @@ object ShippableVertexPartition {
    */
   def apply[VD: ClassTag](
       iter: Iterator[(VertexId, VD)], routingTable: RoutingTablePartition, defaultVal: VD)
-    : ShippableVertexPartition[VD] = {
-    val fullIter = iter ++ routingTable.iterator.map(vid => (vid, defaultVal))
-    val (index, values, mask) = VertexPartitionBase.initFrom(fullIter, (a: VD, b: VD) => a)
-    new ShippableVertexPartition(index, values, mask, routingTable)
+    : ShippableVertexPartition[VD] =
+    apply(iter, routingTable, defaultVal, (a, b) => a)
+
+  /**
+   * Construct a `ShippableVertexPartition` from the given vertices with the specified routing
+   * table, filling in missing vertices mentioned in the routing table using `defaultVal`,
+   * and merging duplicate vertex attributes with `mergeFunc`.
+   */
+  def apply[VD: ClassTag](
+      iter: Iterator[(VertexId, VD)], routingTable: RoutingTablePartition, defaultVal: VD,
+      mergeFunc: (VD, VD) => VD): ShippableVertexPartition[VD] = {
+    val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, VD]
+    // Merge the given vertices using mergeFunc
+    iter.foreach { pair =>
+      map.setMerge(pair._1, pair._2, mergeFunc)
+    }
+    // Fill in missing vertices mentioned in the routing table
+    routingTable.iterator.foreach { vid =>
+      map.changeValue(vid, defaultVal, identity)
+    }
+
+    new ShippableVertexPartition(map.keySet, map._values, map.keySet.getBitSet, routingTable)
   }
 
   import scala.language.implicitConversions
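
The new overload merges first and fills defaults second, so a vertex that appears both in `iter` and in the routing table keeps its merged attribute rather than `defaultVal`. A simplified model of that order, using a plain `mutable.HashMap` in place of the internal `GraphXPrimitiveKeyOpenHashMap` (the helper name `buildPartition` and its arguments are illustrative, not GraphX API):

    import scala.collection.mutable

    // Simplified sketch of the merge-then-fill logic above; not GraphX API.
    def buildPartition[VD](
        vertices: Iterator[(Long, VD)],
        routingTableIds: Iterator[Long],
        defaultVal: VD,
        mergeFunc: (VD, VD) => VD): Map[Long, VD] = {
      val map = mutable.HashMap.empty[Long, VD]
      // Merge duplicate attributes, like setMerge: mergeFunc(existing, incoming).
      vertices.foreach { case (vid, attr) =>
        map(vid) = map.get(vid).map(mergeFunc(_, attr)).getOrElse(attr)
      }
      // Fill routing-table-only vertices, like changeValue(vid, defaultVal, identity):
      // vertices that already have an attribute are left untouched.
      routingTableIds.foreach { vid =>
        map.getOrElseUpdate(vid, defaultVal)
      }
      map.toMap
    }

    // buildPartition(Iterator((1L, 1), (1L, 2)), Iterator(1L, 5L), 0, (a: Int, b: Int) => a + b)
    // == Map(1L -> 3, 5L -> 0): vertex 1 keeps its merged value, vertex 5 gets the default.
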
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
index cc86bafd2d..42d3f21dba 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
@@ -99,4 +99,15 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext {
     }
   }
 
+  test("mergeFunc") {
+    // test to see if the mergeFunc is working correctly
+    withSpark { sc =>
+      val verts = sc.parallelize(List((0L, 0), (1L, 1), (1L, 2), (2L, 3), (2L, 3), (2L, 3)))
+      val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]]))
+      val rdd = VertexRDD(verts, edges, 0, (a: Int, b: Int) => a + b)
+      // test merge function
+      assert(rdd.collect.toSet == Set((0L, 0), (1L, 3), (2L, 9)))
+    }
+  }
+
 }
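
The expected set follows from the merge: vertex 1's duplicates combine as 1 + 2 = 3 and vertex 2's as 3 + 3 + 3 = 9, while vertex 0 is untouched. The same construction can be tried interactively; this sketch assumes a spark-shell session where `sc` is already bound:

    // spark-shell sketch; mirrors the test above rather than adding new behavior.
    import org.apache.spark.graphx._

    val verts = sc.parallelize(List((0L, 0), (1L, 1), (1L, 2), (2L, 3), (2L, 3), (2L, 3)))
    val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]]))
    // Duplicate vertex attributes are combined with the supplied mergeFunc (addition here).
    val rdd = VertexRDD(verts, edges, 0, (a: Int, b: Int) => a + b)
    rdd.collect.toSet  // Set((0,0), (1,3), (2,9))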