From 8e5c7324303ee9a9a61ad35e94ada5638ca0cf70 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 13 Jan 2014 18:45:20 -0800 Subject: Moved SVDPlusPlusConf into SVDPlusPlus object itself. --- .../org/apache/spark/graphx/lib/SVDPlusPlus.scala | 30 ++++++++++++---------- .../apache/spark/graphx/lib/SVDPlusPlusSuite.scala | 2 +- 2 files changed, 17 insertions(+), 15 deletions(-) (limited to 'graphx') diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index fa6b1db29b..ba6517e012 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -5,19 +5,21 @@ import org.apache.commons.math.linear._ import org.apache.spark.rdd._ import org.apache.spark.graphx._ -/** Configuration parameters for SVDPlusPlus. */ -class SVDPlusPlusConf( - var rank: Int, - var maxIters: Int, - var minVal: Double, - var maxVal: Double, - var gamma1: Double, - var gamma2: Double, - var gamma6: Double, - var gamma7: Double) extends Serializable - /** Implementation of SVD++ algorithm. */ object SVDPlusPlus { + + /** Configuration parameters for SVDPlusPlus. */ + class Conf( + var rank: Int, + var maxIters: Int, + var minVal: Double, + var maxVal: Double, + var gamma1: Double, + var gamma2: Double, + var gamma6: Double, + var gamma7: Double) + extends Serializable + /** * Implement SVD++ based on "Factorization Meets the Neighborhood: * a Multifaceted Collaborative Filtering Model", @@ -32,7 +34,7 @@ object SVDPlusPlus { * * @return a graph with vertex attributes containing the trained model */ - def run(edges: RDD[Edge[Double]], conf: SVDPlusPlusConf) + def run(edges: RDD[Edge[Double]], conf: Conf) : (Graph[(RealVector, RealVector, Double, Double), Double], Double) = { // Generate default vertex attribute @@ -64,7 +66,7 @@ object SVDPlusPlus { (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) } - def mapTrainF(conf: SVDPlusPlusConf, u: Double) + def mapTrainF(conf: Conf, u: Double) (et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]) : Iterator[(VertexID, (RealVector, RealVector, Double))] = { val (usr, itm) = (et.srcAttr, et.dstAttr) @@ -112,7 +114,7 @@ object SVDPlusPlus { } // calculate error on training set - def mapTestF(conf: SVDPlusPlusConf, u: Double) + def mapTestF(conf: Conf, u: Double) (et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]) : Iterator[(VertexID, Double)] = { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index a4a1cdab18..e173c652a5 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -18,7 +18,7 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { val fields = line.split(",") Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble) } - val conf = new SVDPlusPlusConf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations + val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations var (graph, u) = SVDPlusPlus.run(edges, conf) graph.cache() val err = graph.vertices.collect.map{ case (vid, vd) => -- cgit v1.2.3