about | summary | refs | log | tree | commit | diff
path: root/graph/src
diff options
context:
space:
mode:
author Ankur Dave <ankurdave@gmail.com> 2013-12-20 12:59:07 -0800
committer Ankur Dave <ankurdave@gmail.com> 2013-12-20 13:00:06 -0800
commit 32508e20d468fcb72fb89e6ae23c9fdd6475f0c8 (patch)
tree fd4335609dc6af1530d90eb91f2ad428af8fb40b /graph/src
parent ac70b8f234493fa670104f0599669500697d2533 (diff)
download spark-32508e20d468fcb72fb89e6ae23c9fdd6475f0c8.tar.gz
spark-32508e20d468fcb72fb89e6ae23c9fdd6475f0c8.tar.bz2
spark-32508e20d468fcb72fb89e6ae23c9fdd6475f0c8.zip
Test VertexPartition and fix bugs
Diffstat (limited to 'graph/src')
-rw-r--r-- graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala 24
-rw-r--r-- graph/src/test/scala/org/apache/spark/graph/impl/VertexPartitionSuite.scala 113
2 files changed, 128 insertions, 9 deletions
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala
index 7710d6eada..9b2d66999c 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala
@@ -188,8 +188,10 @@ class VertexPartition[@specialized(Long, Int, Double) VD: ClassManifest](
val newValues = new Array[VD2](capacity)
iter.foreach { case (vid, vdata) =>
val pos = index.getPos(vid)
- newMask.set(pos)
- newValues(pos) = vdata
+ if (pos >= 0) {
+ newMask.set(pos)
+ newValues(pos) = vdata
+ }
}
new VertexPartition[VD2](index, newValues, newMask)
}
@@ -204,8 +206,10 @@ class VertexPartition[@specialized(Long, Int, Double) VD: ClassManifest](
System.arraycopy(values, 0, newValues, 0, newValues.length)
iter.foreach { case (vid, vdata) =>
val pos = index.getPos(vid)
- newMask.set(pos)
- newValues(pos) = vdata
+ if (pos >= 0) {
+ newMask.set(pos)
+ newValues(pos) = vdata
+ }
}
new VertexPartition(index, newValues, newMask)
}
@@ -219,11 +223,13 @@ class VertexPartition[@specialized(Long, Int, Double) VD: ClassManifest](
val vid = product._1
val vdata = product._2
val pos = index.getPos(vid)
- if (newMask.get(pos)) {
- newValues(pos) = reduceFunc(newValues(pos), vdata)
- } else { // otherwise just store the new value
- newMask.set(pos)
- newValues(pos) = vdata
+ if (pos >= 0) {
+ if (newMask.get(pos)) {
+ newValues(pos) = reduceFunc(newValues(pos), vdata)
+ } else { // otherwise just store the new value
+ newMask.set(pos)
+ newValues(pos) = vdata
+ }
}
}
new VertexPartition[VD2](index, newValues, newMask)
diff --git a/graph/src/test/scala/org/apache/spark/graph/impl/VertexPartitionSuite.scala b/graph/src/test/scala/org/apache/spark/graph/impl/VertexPartitionSuite.scala
new file mode 100644
index 0000000000..72579a48c2
--- /dev/null
+++ b/graph/src/test/scala/org/apache/spark/graph/impl/VertexPartitionSuite.scala
@@ -0,0 +1,113 @@
+package org.apache.spark.graph.impl
+
+import org.apache.spark.graph._
+import org.scalatest.FunSuite
+
+class VertexPartitionSuite extends FunSuite {
+
+ test("isDefined, filter") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 }
+ assert(vp.isDefined(0))
+ assert(!vp.isDefined(1))
+ assert(!vp.isDefined(2))
+ assert(!vp.isDefined(-1))
+ }
+
+ test("isActive, numActives, replaceActives") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1)))
+ .filter { (vid, attr) => vid == 0 }
+ .replaceActives(Iterator(0, 2, 0))
+ assert(vp.isActive(0))
+ assert(!vp.isActive(1))
+ assert(vp.isActive(2))
+ assert(!vp.isActive(-1))
+ assert(vp.numActives == Some(2))
+ }
+
+ test("map") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).map { (vid, attr) => 2 }
+ assert(vp(0) === 2)
+ }
+
+ test("diff") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val vp2 = vp.filter { (vid, attr) => vid <= 1 }
+ val vp3a = vp.map { (vid, attr) => 2 }
+ val vp3b = VertexPartition(vp3a.iterator)
+ // diff with same index
+ val diff1 = vp2.diff(vp3a)
+ assert(diff1(0) === 2)
+ assert(diff1(1) === 2)
+ assert(diff1(2) === 2)
+ assert(!diff1.isDefined(2))
+ // diff with different indexes
+ val diff2 = vp2.diff(vp3b)
+ assert(diff2(0) === 2)
+ assert(diff2(1) === 2)
+ assert(diff2(2) === 2)
+ assert(!diff2.isDefined(2))
+ }
+
+ test("leftJoin") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val vp2a = vp.filter { (vid, attr) => vid <= 1 }.map { (vid, attr) => 2 }
+ val vp2b = VertexPartition(vp2a.iterator)
+ // leftJoin with same index
+ val join1 = vp.leftJoin(vp2a) { (vid, a, bOpt) => bOpt.getOrElse(a) }
+ assert(join1.iterator.toSet === Set((0L, 2), (1L, 2), (2L, 1)))
+ // leftJoin with different indexes
+ val join2 = vp.leftJoin(vp2b) { (vid, a, bOpt) => bOpt.getOrElse(a) }
+ assert(join2.iterator.toSet === Set((0L, 2), (1L, 2), (2L, 1)))
+ // leftJoin an iterator
+ val join3 = vp.leftJoin(vp2a.iterator) { (vid, a, bOpt) => bOpt.getOrElse(a) }
+ assert(join3.iterator.toSet === Set((0L, 2), (1L, 2), (2L, 1)))
+ }
+
+ test("innerJoin") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val vp2a = vp.filter { (vid, attr) => vid <= 1 }.map { (vid, attr) => 2 }
+ val vp2b = VertexPartition(vp2a.iterator)
+ // innerJoin with same index
+ val join1 = vp.innerJoin(vp2a) { (vid, a, b) => b }
+ assert(join1.iterator.toSet === Set((0L, 2), (1L, 2)))
+ // innerJoin with different indexes
+ val join2 = vp.innerJoin(vp2b) { (vid, a, b) => b }
+ assert(join2.iterator.toSet === Set((0L, 2), (1L, 2)))
+ // innerJoin an iterator
+ val join3 = vp.innerJoin(vp2a.iterator) { (vid, a, b) => b }
+ assert(join3.iterator.toSet === Set((0L, 2), (1L, 2)))
+ }
+
+ test("createUsingIndex") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val elems = List((0L, 2), (2L, 2), (3L, 2))
+ val vp2 = vp.createUsingIndex(elems.iterator)
+ assert(vp2.iterator.toSet === Set((0L, 2), (2L, 2)))
+ assert(vp.index === vp2.index)
+ }
+
+ test("innerJoinKeepLeft") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val elems = List((0L, 2), (2L, 2), (3L, 2))
+ val vp2 = vp.innerJoinKeepLeft(elems.iterator)
+ assert(vp2.iterator.toSet === Set((0L, 2), (2L, 2)))
+ assert(vp2(1) === 1)
+ }
+
+ test("aggregateUsingIndex") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val messages = List((0L, "a"), (2L, "b"), (0L, "c"), (3L, "d"))
+ val vp2 = vp.aggregateUsingIndex[String](messages.iterator, _ + _)
+ assert(vp2.iterator.toSet === Set((0L, "ac"), (2L, "b")))
+ }
+
+ test("reindex") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val vp2 = vp.filter { (vid, attr) => vid <= 1 }
+ val vp3 = vp2.reindex()
+ assert(vp2.iterator.toSet === vp3.iterator.toSet)
+ assert(vp2(2) === 1)
+ assert(vp3.index.getPos(2) === -1)
+ }
+
+}