aboutsummaryrefslogtreecommitdiff
path: root/graphx
diff options
context:
space:
mode:
authorReynold Xin <rxin@apache.org>2014-01-13 18:45:20 -0800
committerReynold Xin <rxin@apache.org>2014-01-13 18:45:20 -0800
commit8e5c7324303ee9a9a61ad35e94ada5638ca0cf70 (patch)
tree772e77a4b5a988dd1fecfa4ebe3cda45b46daefe /graphx
parent1dce9ce446dd248755cd65b7a6a0729a4dca2d62 (diff)
downloadspark-8e5c7324303ee9a9a61ad35e94ada5638ca0cf70.tar.gz
spark-8e5c7324303ee9a9a61ad35e94ada5638ca0cf70.tar.bz2
spark-8e5c7324303ee9a9a61ad35e94ada5638ca0cf70.zip
Moved SVDPlusPlusConf into SVDPlusPlus object itself.
Diffstat (limited to 'graphx')
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala30
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala2
2 files changed, 17 insertions, 15 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
index fa6b1db29b..ba6517e012 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
@@ -5,19 +5,21 @@ import org.apache.commons.math.linear._
import org.apache.spark.rdd._
import org.apache.spark.graphx._
-/** Configuration parameters for SVDPlusPlus. */
-class SVDPlusPlusConf(
- var rank: Int,
- var maxIters: Int,
- var minVal: Double,
- var maxVal: Double,
- var gamma1: Double,
- var gamma2: Double,
- var gamma6: Double,
- var gamma7: Double) extends Serializable
-
/** Implementation of SVD++ algorithm. */
object SVDPlusPlus {
+
+ /** Configuration parameters for SVDPlusPlus. */
+ class Conf(
+ var rank: Int,
+ var maxIters: Int,
+ var minVal: Double,
+ var maxVal: Double,
+ var gamma1: Double,
+ var gamma2: Double,
+ var gamma6: Double,
+ var gamma7: Double)
+ extends Serializable
+
/**
* Implement SVD++ based on "Factorization Meets the Neighborhood:
* a Multifaceted Collaborative Filtering Model",
@@ -32,7 +34,7 @@ object SVDPlusPlus {
*
* @return a graph with vertex attributes containing the trained model
*/
- def run(edges: RDD[Edge[Double]], conf: SVDPlusPlusConf)
+ def run(edges: RDD[Edge[Double]], conf: Conf)
: (Graph[(RealVector, RealVector, Double, Double), Double], Double) =
{
// Generate default vertex attribute
@@ -64,7 +66,7 @@ object SVDPlusPlus {
(vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
}
- def mapTrainF(conf: SVDPlusPlusConf, u: Double)
+ def mapTrainF(conf: Conf, u: Double)
(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double])
: Iterator[(VertexID, (RealVector, RealVector, Double))] = {
val (usr, itm) = (et.srcAttr, et.dstAttr)
@@ -112,7 +114,7 @@ object SVDPlusPlus {
}
// calculate error on training set
- def mapTestF(conf: SVDPlusPlusConf, u: Double)
+ def mapTestF(conf: Conf, u: Double)
(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double])
: Iterator[(VertexID, Double)] =
{
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
index a4a1cdab18..e173c652a5 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
@@ -18,7 +18,7 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
val fields = line.split(",")
Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
}
- val conf = new SVDPlusPlusConf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
+ val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
var (graph, u) = SVDPlusPlus.run(edges, conf)
graph.cache()
val err = graph.vertices.collect.map{ case (vid, vd) =>