aboutsummaryrefslogtreecommitdiff
path: root/graph/src
diff options
context:
space:
mode:
authorWang Jianping J <jianping.j.wang@gmail.com>2013-12-20 16:57:24 +0800
committerWang Jianping J <jianping.j.wang@gmail.com>2013-12-20 16:57:24 +0800
commit343d8977aa7d53f381b014778fb60106f9cbcabb (patch)
tree06f1e68fa214a698b690a192fe17a27d7f1d2272 /graph/src
parentda301b57fc7f606e2b8fd0acaf95aa1bd9b643b0 (diff)
downloadspark-343d8977aa7d53f381b014778fb60106f9cbcabb.tar.gz
spark-343d8977aa7d53f381b014778fb60106f9cbcabb.tar.bz2
spark-343d8977aa7d53f381b014778fb60106f9cbcabb.zip
remove unused variable and fix a bug
Diffstat (limited to 'graph/src')
-rw-r--r--graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala15
1 files changed, 6 insertions, 9 deletions
diff --git a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
index 4ddf0b1fd5..ffd0ddba7e 100644
--- a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
@@ -21,6 +21,7 @@ class Msg ( // message
object Svdpp {
// implement SVD++ based on http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf
+ // model (15) on page 6
def run(edges: RDD[Edge[Double]]): Graph[VT, Double] = {
// defalut parameters
@@ -91,15 +92,13 @@ object Svdpp {
assert(et.srcAttr != null && et.dstAttr != null)
val usr = et.srcAttr
val itm = et.dstAttr
- var p = usr.v1
- var q = itm.v1
- val itmBias = 0.0
- val usrBias = 0.0
+ val p = usr.v1
+ val q = itm.v1
var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
pred = math.max(pred, minVal)
pred = math.min(pred, maxVal)
val err = et.attr - pred
- val y = (q.mapMultiply(err*usr.norm)).subtract((usr.v2).mapMultiply(gamma7))
+ val y = (q.mapMultiply(err*usr.norm)).subtract((itm.v2).mapMultiply(gamma7))
val newP = (q.mapMultiply(err)).subtract(p.mapMultiply(gamma7)) // for each connected item q
val newQ = (usr.v2.mapMultiply(err)).subtract(q.mapMultiply(gamma7))
Iterator((et.srcId, new Msg(newP, y, err - gamma6*usr.bias)), (et.dstId, new Msg(newQ, y, err - gamma6*itm.bias)))
@@ -135,10 +134,8 @@ object Svdpp {
assert(et.srcAttr != null && et.dstAttr != null)
val usr = et.srcAttr
val itm = et.dstAttr
- var p = usr.v1
- var q = itm.v1
- val itmBias = 0.0
- val usrBias = 0.0
+ val p = usr.v1
+ val q = itm.v1
var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
pred = math.max(pred, minVal)
pred = math.min(pred, maxVal)