aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/Graph.scala5
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala14
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala2
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala2
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala25
5 files changed, 40 insertions, 8 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index 14ae50e665..4db45c9af8 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -138,7 +138,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* }}}
*
*/
- def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED]
+ def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2)
+ (implicit eq: VD =:= VD2 = null): Graph[VD2, ED]
/**
* Transforms each edge attribute in the graph using the map function. The map function is not
@@ -348,7 +349,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* }}}
*/
def outerJoinVertices[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)])
- (mapFunc: (VertexId, VD, Option[U]) => VD2)
+ (mapFunc: (VertexId, VD, Option[U]) => VD2)(implicit eq: VD =:= VD2 = null)
: Graph[VD2, ED]
/**
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 15ea05cbe2..ccdaa82eb9 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -104,8 +104,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
new GraphImpl(vertices.reverseRoutingTables(), replicatedVertexView.reverse())
}
- override def mapVertices[VD2: ClassTag](f: (VertexId, VD) => VD2): Graph[VD2, ED] = {
- if (classTag[VD] equals classTag[VD2]) {
+ override def mapVertices[VD2: ClassTag]
+ (f: (VertexId, VD) => VD2)(implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = {
+ // The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left
+ // null if not
+ if (eq != null) {
vertices.cache()
// The map preserves type, so we can use incremental replication
val newVerts = vertices.mapVertexPartitions(_.map(f)).cache()
@@ -232,8 +235,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
override def outerJoinVertices[U: ClassTag, VD2: ClassTag]
(other: RDD[(VertexId, U)])
- (updateF: (VertexId, VD, Option[U]) => VD2): Graph[VD2, ED] = {
- if (classTag[VD] equals classTag[VD2]) {
+ (updateF: (VertexId, VD, Option[U]) => VD2)
+ (implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = {
+ // The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left
+ // null if not
+ if (eq != null) {
vertices.cache()
// updateF preserves type, so we can use incremental replication
val newVerts = vertices.leftJoin(other)(updateF).cache()
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala
index 776bfb8dd6..82e9e06515 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala
@@ -41,7 +41,7 @@ object LabelPropagation {
*
* @return a graph with vertex attributes containing the label of community affiliation
*/
- def run[ED: ClassTag](graph: Graph[_, ED], maxSteps: Int): Graph[VertexId, ED] = {
+ def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = {
val lpaGraph = graph.mapVertices { case (vid, _) => vid }
def sendMessage(e: EdgeTriplet[VertexId, ED]) = {
Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L)))
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
index bba070f256..590f047495 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
@@ -49,7 +49,7 @@ object ShortestPaths {
* @return a graph where each vertex attribute is a map containing the shortest-path distance to
* each reachable landmark vertex.
*/
- def run[ED: ClassTag](graph: Graph[_, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
+ def run[VD, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
val spGraph = graph.mapVertices { (vid, attr) =>
if (landmarks.contains(vid)) makeMap(vid -> 0) else makeMap()
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index abc25d0671..6506bac73d 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -159,6 +159,31 @@ class GraphSuite extends FunSuite with LocalSparkContext {
}
}
+ test("mapVertices changing type with same erased type") {
+ withSpark { sc =>
+ val vertices = sc.parallelize(Array[(Long, Option[java.lang.Integer])](
+ (1L, Some(1)),
+ (2L, Some(2)),
+ (3L, Some(3))
+ ))
+ val edges = sc.parallelize(Array(
+ Edge(1L, 2L, 0),
+ Edge(2L, 3L, 0),
+ Edge(3L, 1L, 0)
+ ))
+ val graph0 = Graph(vertices, edges)
+ // Trigger initial vertex replication
+ graph0.triplets.foreach(x => {})
+ // Change type of replicated vertices, but preserve erased type
+ val graph1 = graph0.mapVertices {
+ case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double)
+ }
+ // Access replicated vertices, exposing the erased type
+ val graph2 = graph1.mapTriplets(t => t.srcAttr.get)
+ assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0))
+ }
+ }
+
test("mapEdges") {
withSpark { sc =>
val n = 3