author     Wang Jianping J <jianping.j.wang@gmail.com>   2013-12-17 19:41:04 +0800
committer  Wang Jianping J <jianping.j.wang@gmail.com>   2013-12-17 19:41:04 +0800
commit     772b192910207b5a8cbfaae9573c30efa3add7c3 (patch)
tree       24622cc674f6b8ce2416f6e6827cf109d1ee7d1b /graph/src
parent     9d2351f501b47d02399391475a57ee19ac2a28e1 (diff)
Update AnalyticsSuite.scala: factor the rank comparison into a compareRanks helper, check PageRank.runStandalone against the static and dynamic results, and add a Chain PageRank test.
Diffstat (limited to 'graph/src')
-rw-r--r--  graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala | 91
1 file changed, 56 insertions(+), 35 deletions(-)
diff --git a/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala b/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
index fd70306803..05ebe2b84d 100644
--- a/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
@@ -51,35 +51,38 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
System.setProperty("spark.kryo.registrator", "org.apache.spark.graph.GraphKryoRegistrator")
+ def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = {
+ a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) }
+ .map { case (id, error) => error }.sum
+ }
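For illustration, the same squared-error measure over plain Scala maps (a minimal sketch keyed by vertex id, not part of the patch):

def compareRanksLocal(a: Map[Long, Double], b: Map[Long, Double]): Double =
  a.map { case (id, ra) =>
    val rb = b.getOrElse(id, 0.0)  // missing vertices count as rank 0.0, as in leftJoin above
    (ra - rb) * (ra - rb)
  }.sum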
test("Star PageRank") {
withSpark(new SparkContext("local", "test")) { sc =>
val nVertices = 100
val starGraph = GraphGenerators.starGraph(sc, nVertices).cache()
val resetProb = 0.15
- val prGraph1 = PageRank.run(starGraph, 1, resetProb)
- val prGraph2 = PageRank.run(starGraph, 2, resetProb)
+ val errorTol = 1.0e-5
+
+ val staticRanks1 = PageRank.run(starGraph, numIter = 1, resetProb).vertices.cache()
+ val staticRanks2 = PageRank.run(starGraph, numIter = 2, resetProb).vertices.cache()
- val notMatching = prGraph1.vertices.zipJoin(prGraph2.vertices) { (vid, pr1, pr2) =>
- if (pr1 != pr2) { 1 } else { 0 }
+ // Static PageRank should only take 2 iterations to converge
+ val notMatching = staticRanks1.zipJoin(staticRanks2) { (vid, pr1, pr2) =>
+ if (pr1 != pr2) 1 else 0
}.map { case (vid, test) => test }.sum
assert(notMatching === 0)
- //prGraph2.vertices.foreach(println(_))
- val errors = prGraph2.vertices.map { case (vid, pr) =>
+
+ val staticErrors = staticRanks2.map { case (vid, pr) =>
val correct = (vid > 0 && pr == resetProb) ||
- (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) < 1.0E-5)
- if ( !correct ) { 1 } else { 0 }
+ (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) < 1.0E-5)
+ if (!correct) 1 else 0
}
- assert(errors.sum === 0)
+ assert(staticErrors.sum === 0)
- val prGraph3 = PageRank.runUntillConvergence(starGraph, 0, resetProb)
- val errors2 = prGraph2.vertices.leftJoin(prGraph3.vertices){ (vid, pr1, pr2Opt) =>
- pr2Opt match {
- case Some(pr2) if(pr1 == pr2) => 0
- case _ => 1
- }
- }.map { case (vid, test) => test }.sum
- assert(errors2 === 0)
+ val dynamicRanks = PageRank.runUntillConvergence(starGraph, 0, resetProb).vertices.cache()
+ val standaloneRanks = PageRank.runStandalone(starGraph, 0, resetProb).cache()
+ assert(compareRanks(staticRanks2, dynamicRanks) < errorTol)
+ assert(compareRanks(staticRanks2, standaloneRanks) < errorTol)
}
} // end of test Star PageRank
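Why the star-graph assertion holds (a sketch of the arithmetic, not committed code): every edge points from a leaf into vertex 0 and each leaf has out-degree 1, so after one iteration each leaf settles at resetProb, and on the next iteration the center collects a resetProb contribution from every leaf:

val nVertices = 100
val resetProb = 0.15
val leafRank = resetProb  // leaves have no in-edges
val centerRank = resetProb + (1.0 - resetProb) * resetProb * (nVertices - 1)
// centerRank = 0.15 + 0.85 * 0.15 * 99 = 12.7725, the value the test checks against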
@@ -87,27 +90,46 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
test("Grid PageRank") {
withSpark(new SparkContext("local", "test")) { sc =>
- val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).cache()
+ val rows = 10
+ val cols = 10
val resetProb = 0.15
- val prGraph1 = PageRank.run(gridGraph, 50, resetProb).cache()
- val prGraph2 = PageRank.runUntillConvergence(gridGraph, 0.0001, resetProb).cache()
- val error = prGraph1.vertices.zipJoin(prGraph2.vertices) { case (id, a, b) => (a - b) * (a - b) }
- .map { case (id, error) => error }.sum
- //prGraph1.vertices.zipJoin(prGraph2.vertices) { (id, a, b) => (a, b, a-b) }.foreach(println(_))
- println(error)
- assert(error < 1.0e-5)
- val pr3: RDD[(Vid, Double)] = sc.parallelize(GridPageRank(10,10, 50, resetProb))
- val error2 = prGraph1.vertices.leftJoin(pr3) { (id, a, bOpt) =>
- val b: Double = bOpt.get
- (a - b) * (a - b)
- }.map { case (id, error) => error }.sum
- //prGraph1.vertices.leftJoin(pr3) { (id, a, b) => (a, b) }.foreach( println(_) )
- println(error2)
- assert(error2 < 1.0e-5)
+ val tol = 0.0001
+ val numIter = 50
+ val errorTol = 1.0e-5
+ val gridGraph = GraphGenerators.gridGraph(sc, rows, cols).cache()
+
+ val staticRanks = PageRank.run(gridGraph, numIter, resetProb).vertices.cache()
+ val dynamicRanks = PageRank.runUntillConvergence(gridGraph, tol, resetProb).vertices.cache()
+ val standaloneRanks = PageRank.runStandalone(gridGraph, tol, resetProb).cache()
+ val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb)))
+
+ assert(compareRanks(staticRanks, referenceRanks) < errorTol)
+ assert(compareRanks(dynamicRanks, referenceRanks) < errorTol)
+ assert(compareRanks(standaloneRanks, referenceRanks) < errorTol)
}
} // end of Grid PageRank
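The GridPageRank helper referenced above computes reference ranks sequentially. A minimal sketch of that style of reference implementation (plain Scala over an edge list, assuming ranks start at resetProb and follow the same update rule; illustrative only, not the helper's actual code):

def referencePageRank(nVerts: Int, edges: Seq[(Int, Int)],
    numIter: Int, resetProb: Double): Array[Double] = {
  val outDegree = new Array[Int](nVerts)
  edges.foreach { case (src, _) => outDegree(src) += 1 }
  var ranks = Array.fill(nVerts)(resetProb)
  for (_ <- 0 until numIter) {
    // Each vertex sends its rank, split evenly across its out-edges.
    val contrib = new Array[Double](nVerts)
    edges.foreach { case (src, dst) => contrib(dst) += ranks(src) / outDegree(src) }
    ranks = contrib.map(c => resetProb + (1.0 - resetProb) * c)
  }
  ranks
}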
+ test("Chain PageRank") {
+ withSpark(new SparkContext("local", "test")) { sc =>
+ val chain1 = (0 until 9).map(x => (x, x+1) )
+ val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) }
+ val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+ val resetProb = 0.15
+ val tol = 0.0001
+ val numIter = 10
+ val errorTol = 1.0e-5
+
+ val staticRanks = PageRank.run(chain, numIter, resetProb).vertices.cache()
+ val dynamicRanks = PageRank.runUntillConvergence(chain, tol, resetProb).vertices.cache()
+ val standaloneRanks = PageRank.runStandalone(chain, tol, resetProb).cache()
+
+ assert(compareRanks(staticRanks, dynamicRanks) < errorTol)
+ assert(compareRanks(dynamicRanks, standaloneRanks) < errorTol)
+ }
+ }
+
+
test("Grid Connected Components") {
withSpark(new SparkContext("local", "test")) { sc =>
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).cache()
@@ -167,7 +189,6 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
}
}
val ccMap = vertices.toMap
- println(ccMap)
for ( id <- 0 until 20 ) {
if (id < 10) {
assert(ccMap(id) === 0)
@@ -230,7 +251,7 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
withSpark(new SparkContext("local", "test")) { sc =>
val rawEdges = sc.parallelize(Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
Array(0L -> 1L, 1L -> 2L, 2L -> 0L), 2)
- val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
+ val graph = Graph.fromEdgeTuples(rawEdges, true, uniqueEdges = Some(RandomVertexCut)).cache()
val triangleCount = TriangleCount.run(graph)
val verts = triangleCount.vertices
verts.collect.foreach { case (vid, count) => assert(count === 1) }
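The uniqueEdges = Some(RandomVertexCut) argument merges the duplicated tuples before counting; without it each triangle edge would be present twice. A quick illustration of the duplication in the raw input (a sketch, not committed code):

val rawTuples = Seq(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ Seq(0L -> 1L, 1L -> 2L, 2L -> 0L)
assert(rawTuples.distinct.size == 3)  // only three distinct edges in the doubled list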