author     Ankur Dave <ankurdave@gmail.com>   2013-12-19 20:32:30 -0800
committer  Ankur Dave <ankurdave@gmail.com>   2013-12-19 20:32:30 -0800
commit     a69465b1fa7250d036e1585543c225b6340e4790 (patch)
tree       bbe359648ddc556d180d0602dd23aceefd314c0b /graph/src
parent     da9f5e3fc093a91e0e91bc9311d5f5d085dbc929 (diff)
Split VertexRDD tests; fix #114
Diffstat (limited to 'graph/src')
-rw-r--r--  graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala            | 26
-rw-r--r--  graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala | 38
-rw-r--r--  graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala       |  2
-rw-r--r--  graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala           |  5
-rw-r--r--  graph/src/test/scala/org/apache/spark/graph/VertexRDDSuite.scala       | 89
5 files changed, 86 insertions(+), 74 deletions(-)
diff --git a/graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala b/graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala
index 90ac6dc61d..fe0f0ae491 100644
--- a/graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/VertexRDD.scala
@@ -188,28 +188,6 @@ class VertexRDD[@specialized VD: ClassManifest](
}
/**
- * Inner join this VertexSet with another VertexSet which has the
- * same Index. This function will fail if both VertexSets do not
- * share the same index. The resulting vertex set will only contain
- * vertices that are in both this and the other vertex set.
- *
- * @tparam VD2 the attribute type of the other VertexSet
- * @tparam VD3 the attribute type of the resulting VertexSet
- *
- * @param other the other VertexSet with which to join.
- * @param f the function mapping a vertex id and its attributes in
- * this and the other vertex set to a new vertex attribute.
- * @return a VertexRDD containing only the vertices in both this
- * and the other VertexSet and with tuple attributes.
- */
- def zipJoin[VD2: ClassManifest, VD3: ClassManifest]
- (other: VertexRDD[VD2])(f: (Vid, VD, VD2) => VD3): VertexRDD[VD3] = {
- this.zipVertexPartitions(other) { (thisPart, otherPart) =>
- thisPart.join(otherPart)(f)
- }
- }
-
- /**
* Left join this VertexSet with another VertexSet which has the
* same Index. This function will fail if both VertexSets do not
* share the same index. The resulting vertex set contains an entry
@@ -309,6 +287,10 @@ class VertexRDD[@specialized VD: ClassManifest](
}
}
+ /**
+ * Aggregate messages with the same ids using `reduceFunc`, returning a VertexRDD that is
+ * co-indexed with this one.
+ */
def aggregateUsingIndex[VD2: ClassManifest](
messages: RDD[(Vid, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] =
{
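The doc comment added above describes `aggregateUsingIndex`: messages sharing a vertex id are combined with `reduceFunc`, and the result reuses this VertexRDD's index. A minimal usage sketch, not part of the commit, assuming a live SparkContext `sc` and this prototype's API (`Vid` is an alias for Long):

    // Vertices 0..4, each with attribute 0.
    val verts = VertexRDD(sc.parallelize((0L to 4L).map(id => (id, 0))))
    // Two messages target vertex 1; one targets vertex 2.
    val messages = sc.parallelize(Seq((1L, 5), (1L, 7), (2L, 3)))
    // Duplicate ids are reduced: vertex 1 ends up with 5 + 7 = 12.
    val summed = verts.aggregateUsingIndex[Int](messages, _ + _)

Because `summed` is co-indexed with `verts`, subsequent zip joins between the two can proceed partition-by-partition without re-hashing.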
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala
index ccbc83c512..7710d6eada 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala
@@ -125,27 +125,6 @@ class VertexPartition[@specialized(Long, Int, Double) VD: ClassManifest](
}
}
- /** Inner join another VertexPartition. */
- def join[VD2: ClassManifest, VD3: ClassManifest]
- (other: VertexPartition[VD2])
- (f: (Vid, VD, VD2) => VD3): VertexPartition[VD3] =
- {
- if (index != other.index) {
- logWarning("Joining two VertexPartitions with different indexes is slow.")
- join(createUsingIndex(other.iterator))(f)
- } else {
- val newValues = new Array[VD3](capacity)
- val newMask = mask & other.mask
-
- var i = newMask.nextSetBit(0)
- while (i >= 0) {
- newValues(i) = f(index.getValue(i), values(i), other.values(i))
- i = mask.nextSetBit(i + 1)
- }
- new VertexPartition(index, newValues, newMask)
- }
- }
-
/** Left outer join another VertexPartition. */
def leftJoin[VD2: ClassManifest, VD3: ClassManifest]
(other: VertexPartition[VD2])
@@ -179,15 +158,16 @@ class VertexPartition[@specialized(Long, Int, Double) VD: ClassManifest](
if (index != other.index) {
logWarning("Joining two VertexPartitions with different indexes is slow.")
innerJoin(createUsingIndex(other.iterator))(f)
+ } else {
+ val newMask = mask & other.mask
+ val newValues = new Array[VD2](capacity)
+ var i = newMask.nextSetBit(0)
+ while (i >= 0) {
+ newValues(i) = f(index.getValue(i), values(i), other.values(i))
+ i = newMask.nextSetBit(i + 1)
+ }
+ new VertexPartition(index, newValues, newMask)
}
- val newMask = mask & other.mask
- val newValues = new Array[VD2](capacity)
- var i = newMask.nextSetBit(0)
- while (i >= 0) {
- newValues(i) = f(index.getValue(i), values(i), other.values(i))
- i = newMask.nextSetBit(i + 1)
- }
- new VertexPartition(index, newValues, newMask)
}
/**
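The restructured `innerJoin` above first intersects the two validity masks and then iterates only over the surviving set bits; note that the deleted `join` advanced its loop with `mask.nextSetBit(i + 1)` while `innerJoin` uses `newMask` throughout. A self-contained sketch of the mask-walking pattern, using `java.util.BitSet` in place of GraphX's internal index and mask types (illustration only, not the commit's code):

    import java.util.BitSet

    // Two partitions as parallel value arrays plus masks marking live slots.
    val values      = Array(10, 20, 30, 40)
    val otherValues = Array( 1,  2,  3,  4)
    val mask      = new BitSet(4); mask.set(0); mask.set(2); mask.set(3)
    val otherMask = new BitSet(4); otherMask.set(2); otherMask.set(3)

    // Inner join: only slots live in both masks survive.
    val newMask = mask.clone().asInstanceOf[BitSet]
    newMask.and(otherMask)
    val newValues = new Array[Int](values.length)
    var i = newMask.nextSetBit(0)
    while (i >= 0) {
      newValues(i) = values(i) - otherValues(i) // apply the join function f
      i = newMask.nextSetBit(i + 1)             // advance over the joined mask
    }
    // Slots 2 and 3 hold 27 and 36; slots 0 and 1 stay dead in newMask.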
diff --git a/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala b/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
index 1e6d8ec7cf..77a193a9ac 100644
--- a/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
@@ -62,7 +62,7 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
val staticRanks2 = PageRank.run(starGraph, numIter = 2, resetProb).vertices.cache()
// Static PageRank should only take 2 iterations to converge
- val notMatching = staticRanks1.zipJoin(staticRanks2) { (vid, pr1, pr2) =>
+ val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) =>
if (pr1 != pr2) 1 else 0
}.map { case (vid, test) => test }.sum
assert(notMatching === 0)
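`innerZipJoin` is the renamed same-index inner join exercised here. A minimal sketch of the comparison pattern, not part of the commit, assuming a SparkContext `sc`:

    // b is derived from a, so the two vertex sets share an index.
    val a = VertexRDD(sc.parallelize((0L to 10L).map(id => (id, id * 2.0))))
    val b = a.mapValues(x => x)
    // Flag disagreeing positions and count them, as the PageRank test does.
    val notMatching = a.innerZipJoin(b) { (vid, x, y) => if (x != y) 1 else 0 }
      .map { case (_, flag) => flag }.sum
    // notMatching == 0 since every attribute matches.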
diff --git a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
index 09da102350..487d949e1f 100644
--- a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
@@ -1,13 +1,9 @@
package org.apache.spark.graph
-import scala.util.Random
-
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import org.apache.spark.graph.Graph._
-import org.apache.spark.graph.impl.EdgePartition
-import org.apache.spark.graph.impl.EdgePartitionBuilder
import org.apache.spark.rdd._
class GraphSuite extends FunSuite with LocalSparkContext {
@@ -118,6 +114,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
}
test("mapTriplets") {
+ // TODO(ankurdave): Write the test
}
test("reverse") {
diff --git a/graph/src/test/scala/org/apache/spark/graph/VertexRDDSuite.scala b/graph/src/test/scala/org/apache/spark/graph/VertexRDDSuite.scala
index 316968bbf0..e876b8e4e8 100644
--- a/graph/src/test/scala/org/apache/spark/graph/VertexRDDSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/VertexRDDSuite.scala
@@ -1,32 +1,85 @@
package org.apache.spark.graph
-import scala.util.Random
-
-import org.scalatest.FunSuite
-
import org.apache.spark.SparkContext
import org.apache.spark.graph.Graph._
import org.apache.spark.graph.impl.EdgePartition
-import org.apache.spark.graph.impl.EdgePartitionBuilder
import org.apache.spark.rdd._
+import org.scalatest.FunSuite
class VertexRDDSuite extends FunSuite with LocalSparkContext {
- test("VertexRDD") {
+ def vertices(sc: SparkContext, n: Int) = {
+ VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5))
+ }
+
+ test("filter") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n)
+ val evens = verts.filter(q => ((q._2 % 2) == 0))
+ assert(evens.count === (0 to n).filter(_ % 2 == 0).size)
+ }
+ }
+
+ test("mapValues") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n)
+ val negatives = verts.mapValues(x => -x).cache()
+ assert(negatives.count === n + 1)
+ }
+ }
+
+ test("diff") {
withSpark { sc =>
val n = 100
- val a = sc.parallelize((0 to n).map(x => (x.toLong, x.toLong)), 5)
- val b = VertexRDD(a).mapValues(x => -x).cache() // Allow joining b with a derived RDD of b
- assert(b.count === n + 1)
- assert(b.leftJoin(a){ (id, a, bOpt) => a + bOpt.get }.map(x=> x._2).reduce(_+_) === 0)
- val c = b.aggregateUsingIndex[Long](a, (x, y) => x)
- assert(b.leftJoin(c){ (id, b, cOpt) => b + cOpt.get }.map(x=> x._2).reduce(_+_) === 0)
- val d = c.filter(q => ((q._2 % 2) == 0))
- val e = a.filter(q => ((q._2 % 2) == 0))
- assert(d.count === e.count)
- assert(b.zipJoin(c)((id, b, c) => b + c).map(x => x._2).reduce(_+_) === 0)
- val f = b.mapValues(x => if (x % 2 == 0) -x else x)
- assert(b.diff(f).collect().toSet === (2 to n by 2).map(x => (x.toLong, x.toLong)).toSet)
+ val verts = vertices(sc, n)
+ val flipEvens = verts.mapValues(x => if (x % 2 == 0) -x else x)
+ // diff should keep only the changed vertices
+ assert(verts.diff(flipEvens).map(_._2).collect().toSet === (2 to n by 2).map(-_).toSet)
+ // diff should keep the vertex values from `other`
+ assert(flipEvens.diff(verts).map(_._2).collect().toSet === (2 to n by 2).toSet)
}
}
+
+ test("leftJoin") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n)
+ val evens = verts.filter(q => ((q._2 % 2) == 0))
+ // leftJoin with another VertexRDD
+ assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet ===
+ (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet)
+ // leftJoin with an RDD
+ val evensRDD = evens.map(identity)
+ assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet ===
+ (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet)
+ }
+ }
+
+ test("innerJoin") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n)
+ val evens = verts.filter(q => ((q._2 % 2) == 0))
+ // innerJoin with another VertexRDD
+ assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect.toSet ===
+ (0 to n by 2).map(x => (x.toLong, 0)).toSet)
+ // innerJoin with an RDD
+ val evensRDD = evens.map(identity)
+ assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect.toSet ===
+ (0 to n by 2).map(x => (x.toLong, 0)).toSet)
+ }
+ }
+
+ test("aggregateUsingIndex") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n)
+ val messageTargets = (0 to n) ++ (0 to n by 2)
+ val messages = sc.parallelize(messageTargets.map(x => (x.toLong, 1)))
+ assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect.toSet ===
+ (0 to n).map(x => (x.toLong, if (x % 2 == 0) 2 else 1)).toSet)
+ }
+ }
+
}