aboutsummaryrefslogtreecommitdiff
path: root/graphx
diff options
context:
space:
mode:
authorZheng RuiFeng <ruifengz@foxmail.com>2016-03-16 11:52:25 -0700
committerReynold Xin <rxin@databricks.com>2016-03-16 11:52:25 -0700
commit91984978e7c86650d4cf523724b4a1aeaaecf260 (patch)
tree6e1d0fecda1b0d66128d754f7508f901218c60c8 /graphx
parentd9670f84739b0840501b19b8cb0e851850edb8c1 (diff)
downloadspark-91984978e7c86650d4cf523724b4a1aeaaecf260.tar.gz
spark-91984978e7c86650d4cf523724b4a1aeaaecf260.tar.bz2
spark-91984978e7c86650d4cf523724b4a1aeaaecf260.zip
[SPARK-13816][GRAPHX] Add parameter checks for algorithms in Graphx
JIRA: https://issues.apache.org/jira/browse/SPARK-13816 ## What changes were proposed in this pull request? Add parameter checks for algorithms in Graphx: Pregel,LabelPropagation,PageRank,SVDPlusPlus ## How was this patch tested? manual tests Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #11655 from zhengruifeng/graphx_param_check.
Diffstat (limited to 'graphx')
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala3
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala4
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala2
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala9
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala5
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala4
6 files changed, 25 insertions, 2 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 3ba73b4c96..efdc2481fe 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -119,6 +119,9 @@ object Pregel extends Logging {
mergeMsg: (A, A) => A)
: Graph[VD, ED] =
{
+ require(maxIterations > 0, s"Maximum of iterations must be greater than 0," +
+ s" but got ${maxIterations}")
+
var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
// compute the messages
var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
index 40cf0735e2..137c512c99 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
@@ -36,7 +36,9 @@ object ConnectedComponents {
*/
def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED],
maxIterations: Int): Graph[VertexId, ED] = {
- require(maxIterations > 0)
+ require(maxIterations > 0, s"Maximum of iterations must be greater than 0," +
+ s" but got ${maxIterations}")
+
val ccGraph = graph.mapVertices { case (vid, _) => vid }
def sendMessage(edge: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, VertexId)] = {
if (edge.srcAttr < edge.dstAttr) {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala
index 7a53eca7ea..fc7547a2c7 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala
@@ -43,6 +43,8 @@ object LabelPropagation {
* @return a graph with vertex attributes containing the label of community affiliation
*/
def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = {
+ require(maxSteps > 0, s"Maximum of steps must be greater than 0, but got ${maxSteps}")
+
val lpaGraph = graph.mapVertices { case (vid, _) => vid }
def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, Long])] = {
Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L)))
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index 00ba358a9b..9d9a26ebeb 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -104,6 +104,11 @@ object PageRank extends Logging {
graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15,
srcId: Option[VertexId] = None): Graph[Double, Double] =
{
+ require(numIter > 0, s"Number of iterations must be greater than 0," +
+ s" but got ${numIter}")
+ require(resetProb >= 0 && resetProb <= 1, s"Random reset probability must belong" +
+ s" to [0, 1], but got ${resetProb}")
+
val personalized = srcId isDefined
val src: VertexId = srcId.getOrElse(-1L)
@@ -197,6 +202,10 @@ object PageRank extends Logging {
graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15,
srcId: Option[VertexId] = None): Graph[Double, Double] =
{
+ require(tol >= 0, s"Tolerance must be no less than 0, but got ${tol}")
+ require(resetProb >= 0 && resetProb <= 1, s"Random reset probability must belong" +
+ s" to [0, 1], but got ${resetProb}")
+
val personalized = srcId.isDefined
val src: VertexId = srcId.getOrElse(-1L)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
index 78a5cb057d..bb2ffab0f6 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
@@ -56,6 +56,11 @@ object SVDPlusPlus {
def run(edges: RDD[Edge[Double]], conf: Conf)
: (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) =
{
+ require(conf.maxIters > 0, s"Maximum of iterations must be greater than 0," +
+ s" but got ${conf.maxIters}")
+ require(conf.maxVal > conf.minVal, s"MaxVal must be greater than MinVal," +
+ s" but got {maxVal: ${conf.maxVal}, minVal: ${conf.minVal}}")
+
// Generate default vertex attribute
def defaultF(rank: Int): (Array[Double], Array[Double], Double, Double) = {
// TODO: use a fixed random seed
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
index 7063137d47..1fa92b0195 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
@@ -36,7 +36,9 @@ object StronglyConnectedComponents {
* @return a graph with vertex attributes containing the smallest vertex id in each SCC
*/
def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int): Graph[VertexId, ED] = {
- require(numIter > 0, s"Number of iterations ${numIter} must be greater than 0.")
+ require(numIter > 0, s"Number of iterations must be greater than 0," +
+ s" but got ${numIter}")
+
// the graph we update with final SCC ids, and the graph we return at the end
var sccGraph = graph.mapVertices { case (vid, _) => vid }
// graph we are going to work with in our iterations