aboutsummaryrefslogtreecommitdiff
path: root/graph/src
diff options
context:
space:
mode:
authorAnkur Dave <ankurdave@gmail.com>2013-12-20 13:36:14 -0800
committerAnkur Dave <ankurdave@gmail.com>2013-12-20 13:38:19 -0800
commit6bb077cd3de5ce959576ac21b0ae917452802cbc (patch)
tree0b7f38b97411440a983794eabd2b00ffc80f5478 /graph/src
parentac70b8f234493fa670104f0599669500697d2533 (diff)
downloadspark-6bb077cd3de5ce959576ac21b0ae917452802cbc.tar.gz
spark-6bb077cd3de5ce959576ac21b0ae917452802cbc.tar.bz2
spark-6bb077cd3de5ce959576ac21b0ae917452802cbc.zip
Reuse VTableReplicated in GraphImpl.subgraph
Diffstat (limited to 'graph/src')
-rw-r--r--graph/src/main/scala/org/apache/spark/graph/impl/EdgePartition.scala10
-rw-r--r--graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala14
-rw-r--r--graph/src/test/scala/org/apache/spark/graph/impl/EdgePartitionSuite.scala11
3 files changed, 28 insertions, 7 deletions
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/EdgePartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/EdgePartition.scala
index e97522feae..3430ffdfc4 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/EdgePartition.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/EdgePartition.scala
@@ -56,6 +56,16 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
new EdgePartition(srcIds, dstIds, newData, index)
}
+ def filter(pred: Edge[ED] => Boolean): EdgePartition[ED] = {
+ val builder = new EdgePartitionBuilder[ED]
+ iterator.foreach { e =>
+ if (pred(e)) {
+ builder.add(e.srcId, e.dstId, e.attr)
+ }
+ }
+ builder.toEdgePartition
+ }
+
/**
* Apply the function f to all edges in this partition.
*
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index e7f975253a..9e44f49113 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -196,14 +196,12 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
override def subgraph(
epred: EdgeTriplet[VD, ED] => Boolean = x => true,
vpred: (Vid, VD) => Boolean = (a, b) => true): Graph[VD, ED] = {
+ // Filter the vertices, reusing the partitioner and the index from this graph
+ val newVerts = vertices.mapVertexPartitions(_.filter(vpred))
- // Filter the vertices, reusing the partitioner (but not the index) from
- // this graph
- val newVTable = vertices.mapVertexPartitions(_.filter(vpred).reindex())
-
+ // Filter the edges
val edManifest = classManifest[ED]
-
- val newETable = new EdgeRDD[ED](triplets.filter { et =>
+ val newEdges = new EdgeRDD[ED](triplets.filter { et =>
vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)
}.mapPartitionsWithIndex( { (pid, iter) =>
val builder = new EdgePartitionBuilder[ED]()(edManifest)
@@ -212,7 +210,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
Iterator((pid, edgePartition))
}, preservesPartitioning = true)).cache()
- new GraphImpl(newVTable, newETable)
+ // Reuse the previous VTableReplicated unmodified. It will contain extra vertices, which is
+ // fine.
+ new GraphImpl(newVerts, newEdges, new VertexPlacement(newEdges, newVerts), vTableReplicated)
} // end of subgraph
override def mask[VD2: ClassManifest, ED2: ClassManifest] (
diff --git a/graph/src/test/scala/org/apache/spark/graph/impl/EdgePartitionSuite.scala b/graph/src/test/scala/org/apache/spark/graph/impl/EdgePartitionSuite.scala
index a52a5653e2..2991533e89 100644
--- a/graph/src/test/scala/org/apache/spark/graph/impl/EdgePartitionSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/impl/EdgePartitionSuite.scala
@@ -31,6 +31,17 @@ class EdgePartitionSuite extends FunSuite {
edges.map(e => e.copy(attr = e.srcId + e.dstId)))
}
+ test("filter") {
+ val edges = List(Edge(0, 1, 0), Edge(1, 2, 0), Edge(2, 0, 0))
+ val builder = new EdgePartitionBuilder[Int]
+ for (e <- edges) {
+ builder.add(e.srcId, e.dstId, e.attr)
+ }
+ val edgePartition = builder.toEdgePartition
+ assert(edgePartition.filter(e => e.srcId <= 1).iterator.map(_.copy()).toList ===
+ edges.filter(e => e.srcId <= 1))
+ }
+
test("groupEdges") {
val edges = List(
Edge(0, 1, 1), Edge(1, 2, 2), Edge(2, 0, 4), Edge(0, 1, 8), Edge(1, 2, 16), Edge(2, 0, 32))