diff options
author | Reynold Xin <rxin@apache.org> | 2014-01-13 18:45:20 -0800 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-01-13 18:45:20 -0800 |
commit | 8e5c7324303ee9a9a61ad35e94ada5638ca0cf70 (patch) | |
tree | 772e77a4b5a988dd1fecfa4ebe3cda45b46daefe | |
parent | 1dce9ce446dd248755cd65b7a6a0729a4dca2d62 (diff) | |
download | spark-8e5c7324303ee9a9a61ad35e94ada5638ca0cf70.tar.gz spark-8e5c7324303ee9a9a61ad35e94ada5638ca0cf70.tar.bz2 spark-8e5c7324303ee9a9a61ad35e94ada5638ca0cf70.zip |
Moved SVDPlusPlusConf into SVDPlusPlus object itself.
-rw-r--r-- | graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala | 30 | ||||
-rw-r--r-- | graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala | 2 |
2 files changed, 17 insertions, 15 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index fa6b1db29b..ba6517e012 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -5,19 +5,21 @@ import org.apache.commons.math.linear._ import org.apache.spark.rdd._ import org.apache.spark.graphx._ -/** Configuration parameters for SVDPlusPlus. */ -class SVDPlusPlusConf( - var rank: Int, - var maxIters: Int, - var minVal: Double, - var maxVal: Double, - var gamma1: Double, - var gamma2: Double, - var gamma6: Double, - var gamma7: Double) extends Serializable - /** Implementation of SVD++ algorithm. */ object SVDPlusPlus { + + /** Configuration parameters for SVDPlusPlus. */ + class Conf( + var rank: Int, + var maxIters: Int, + var minVal: Double, + var maxVal: Double, + var gamma1: Double, + var gamma2: Double, + var gamma6: Double, + var gamma7: Double) + extends Serializable + /** * Implement SVD++ based on "Factorization Meets the Neighborhood: * a Multifaceted Collaborative Filtering Model", @@ -32,7 +34,7 @@ object SVDPlusPlus { * * @return a graph with vertex attributes containing the trained model */ - def run(edges: RDD[Edge[Double]], conf: SVDPlusPlusConf) + def run(edges: RDD[Edge[Double]], conf: Conf) : (Graph[(RealVector, RealVector, Double, Double), Double], Double) = { // Generate default vertex attribute @@ -64,7 +66,7 @@ object SVDPlusPlus { (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) } - def mapTrainF(conf: SVDPlusPlusConf, u: Double) + def mapTrainF(conf: Conf, u: Double) (et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]) : Iterator[(VertexID, (RealVector, RealVector, Double))] = { val (usr, itm) = (et.srcAttr, et.dstAttr) @@ -112,7 +114,7 @@ object SVDPlusPlus { } // calculate error on training set - def mapTestF(conf: SVDPlusPlusConf, u: Double) + def mapTestF(conf: Conf, u: Double) (et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]) : Iterator[(VertexID, Double)] = { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index a4a1cdab18..e173c652a5 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -18,7 +18,7 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { val fields = line.split(",") Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble) } - val conf = new SVDPlusPlusConf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations + val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations var (graph, u) = SVDPlusPlus.run(edges, conf) graph.cache() val err = graph.vertices.collect.map{ case (vid, vd) => |