aboutsummaryrefslogtreecommitdiff
path: root/graph/src
diff options
context:
space:
mode:
Diffstat (limited to 'graph/src')
-rw-r--r--graph/src/main/scala/org/apache/spark/graph/EdgeRDD.scala29
-rw-r--r--graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala43
-rw-r--r--graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala12
3 files changed, 35 insertions, 49 deletions
diff --git a/graph/src/main/scala/org/apache/spark/graph/EdgeRDD.scala b/graph/src/main/scala/org/apache/spark/graph/EdgeRDD.scala
index 6f1d790325..230202d6b0 100644
--- a/graph/src/main/scala/org/apache/spark/graph/EdgeRDD.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/EdgeRDD.scala
@@ -53,34 +53,17 @@ class EdgeRDD[@specialized ED: ClassTag](
}, preservesPartitioning = true))
}
- def zipEdgePartitions[T: ClassTag, U: ClassTag]
- (other: RDD[T])
- (f: (Pid, EdgePartition[ED], Iterator[T]) => Iterator[U]): RDD[U] = {
- partitionsRDD.zipPartitions(other, preservesPartitioning = true) { (ePartIter, otherIter) =>
- val (pid, edgePartition) = ePartIter.next()
- f(pid, edgePartition, otherIter)
- }
- }
-
- def zipEdgePartitions[ED2: ClassTag, ED3: ClassTag]
- (other: EdgeRDD[ED2])
- (f: (Pid, EdgePartition[ED], EdgePartition[ED2]) => EdgePartition[ED3]): EdgeRDD[ED3] = {
- new EdgeRDD[ED3](partitionsRDD.zipPartitions(other.partitionsRDD, preservesPartitioning = true) {
- (thisIter, otherIter) =>
- val (pid, thisEPart) = thisIter.next()
- val (_, otherEPart) = otherIter.next()
- Iterator(Tuple2(pid, f(pid, thisEPart, otherEPart)))
- })
- }
-
def innerJoin[ED2: ClassTag, ED3: ClassTag]
(other: EdgeRDD[ED2])
(f: (Vid, Vid, ED, ED2) => ED3): EdgeRDD[ED3] = {
val ed2Tag = classTag[ED2]
val ed3Tag = classTag[ED3]
- zipEdgePartitions(other) { (pid, thisEPart, otherEPart) =>
- thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag)
- }
+ new EdgeRDD[ED3](partitionsRDD.zipPartitions(other.partitionsRDD, true) {
+ (thisIter, otherIter) =>
+ val (pid, thisEPart) = thisIter.next()
+ val (_, otherEPart) = otherIter.next()
+ Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag)))
+ })
}
def collectVids(): RDD[Vid] = {
diff --git a/graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala b/graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala
index 8e5e319928..c5fb4aeca7 100644
--- a/graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala
@@ -119,22 +119,6 @@ class VertexRDD[@specialized VD: ClassTag](
new VertexRDD(newPartitionsRDD)
}
- /**
- * Return a new VertexRDD by applying a function to corresponding
- * VertexPartitions of this VertexRDD and another one.
- */
- def zipVertexPartitions[VD2: ClassTag, VD3: ClassTag]
- (other: VertexRDD[VD2])
- (f: (VertexPartition[VD], VertexPartition[VD2]) => VertexPartition[VD3]): VertexRDD[VD3] = {
- val newPartitionsRDD = partitionsRDD.zipPartitions(
- other.partitionsRDD, preservesPartitioning = true
- ) { (thisIter, otherIter) =>
- val thisPart = thisIter.next()
- val otherPart = otherIter.next()
- Iterator(f(thisPart, otherPart))
- }
- new VertexRDD(newPartitionsRDD)
- }
/**
* Restrict the vertex set to the set of vertices satisfying the
@@ -184,9 +168,14 @@ class VertexRDD[@specialized VD: ClassTag](
* the values from `other`.
*/
def diff(other: VertexRDD[VD]): VertexRDD[VD] = {
- this.zipVertexPartitions(other) { (thisPart, otherPart) =>
- thisPart.diff(otherPart)
+ val newPartitionsRDD = partitionsRDD.zipPartitions(
+ other.partitionsRDD, preservesPartitioning = true
+ ) { (thisIter, otherIter) =>
+ val thisPart = thisIter.next()
+ val otherPart = otherIter.next()
+ Iterator(thisPart.diff(otherPart))
}
+ new VertexRDD(newPartitionsRDD)
}
/**
@@ -209,9 +198,14 @@ class VertexRDD[@specialized VD: ClassTag](
*/
def leftZipJoin[VD2: ClassTag, VD3: ClassTag]
(other: VertexRDD[VD2])(f: (Vid, VD, Option[VD2]) => VD3): VertexRDD[VD3] = {
- this.zipVertexPartitions(other) { (thisPart, otherPart) =>
- thisPart.leftJoin(otherPart)(f)
+ val newPartitionsRDD = partitionsRDD.zipPartitions(
+ other.partitionsRDD, preservesPartitioning = true
+ ) { (thisIter, otherIter) =>
+ val thisPart = thisIter.next()
+ val otherPart = otherIter.next()
+ Iterator(thisPart.leftJoin(otherPart)(f))
}
+ new VertexRDD(newPartitionsRDD)
}
/**
@@ -261,9 +255,14 @@ class VertexRDD[@specialized VD: ClassTag](
*/
def innerZipJoin[U: ClassTag, VD2: ClassTag](other: VertexRDD[U])
(f: (Vid, VD, U) => VD2): VertexRDD[VD2] = {
- this.zipVertexPartitions(other) { (thisPart, otherPart) =>
- thisPart.innerJoin(otherPart)(f)
+ val newPartitionsRDD = partitionsRDD.zipPartitions(
+ other.partitionsRDD, preservesPartitioning = true
+ ) { (thisIter, otherIter) =>
+ val thisPart = thisIter.next()
+ val otherPart = otherIter.next()
+ Iterator(thisPart.innerJoin(otherPart)(f))
}
+ new VertexRDD(newPartitionsRDD)
}
/**
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 826c1074a8..4d35755e7e 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -49,7 +49,9 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
@transient override val triplets: RDD[EdgeTriplet[VD, ED]] = {
val vdTag = classTag[VD]
val edTag = classTag[ED]
- edges.zipEdgePartitions(replicatedVertexView.get(true, true)) { (pid, ePart, vPartIter) =>
+ edges.partitionsRDD.zipPartitions(
+ replicatedVertexView.get(true, true), true) { (ePartIter, vPartIter) =>
+ val (pid, ePart) = ePartIter.next()
val (_, vPart) = vPartIter.next()
new EdgeTripletIterator(vPart.index, vPart.values, ePart)(vdTag, edTag)
}
@@ -182,8 +184,9 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
// manifest from GraphImpl (which would require serializing GraphImpl).
val vdTag = classTag[VD]
val newEdgePartitions =
- edges.zipEdgePartitions(replicatedVertexView.get(true, true)) {
- (ePid, edgePartition, vTableReplicatedIter) =>
+ edges.partitionsRDD.zipPartitions(replicatedVertexView.get(true, true), true) {
+ (ePartIter, vTableReplicatedIter) =>
+ val (ePid, edgePartition) = ePartIter.next()
val (vPid, vPart) = vTableReplicatedIter.next()
assert(!vTableReplicatedIter.hasNext)
assert(ePid == vPid)
@@ -267,7 +270,8 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
val activeDirectionOpt = activeSetOpt.map(_._2)
// Map and combine.
- val preAgg = edges.zipEdgePartitions(vs) { (ePid, edgePartition, vPartIter) =>
+ val preAgg = edges.partitionsRDD.zipPartitions(vs, true) { (ePartIter, vPartIter) =>
+ val (ePid, edgePartition) = ePartIter.next()
val (vPid, vPart) = vPartIter.next()
assert(!vPartIter.hasNext)
assert(ePid == vPid)