aboutsummaryrefslogtreecommitdiff
path: root/graph/src
diff options
context:
space:
mode:
authorWang Jianping J <jianping.j.wang@gmail.com>2013-12-21 15:51:05 +0800
committerWang Jianping J <jianping.j.wang@gmail.com>2013-12-21 15:51:05 +0800
commit49eb0f1351fc73bcb44fffbd30d955a95bd88fbe (patch)
tree8dc711b15d88b2994b4dc75f3421d497a2a36709 /graph/src
parentf986e4a13662724f0ff8a31a46133616aa2ca1e0 (diff)
downloadspark-49eb0f1351fc73bcb44fffbd30d955a95bd88fbe.tar.gz
spark-49eb0f1351fc73bcb44fffbd30d955a95bd88fbe.tar.bz2
spark-49eb0f1351fc73bcb44fffbd30d955a95bd88fbe.zip
Update Svdpp.scala
Diffstat (limited to 'graph/src')
-rw-r--r--graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala14
1 files changed, 5 insertions, 9 deletions
diff --git a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
index fe0093c4a4..ac20f15072 100644
--- a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
@@ -98,17 +98,15 @@ object Svdpp {
// phase 2
def mapF2(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Msg)] = {
assert(et.srcAttr != null && et.dstAttr != null)
- val usr = et.srcAttr
- val itm = et.dstAttr
- val p = usr.v1
- val q = itm.v1
+ val (usr, itm) = (et.srcAttr, et.dstAttr)
+ val (p, q) = (usr.v1, itm.v1)
var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
pred = math.max(pred, minVal)
pred = math.min(pred, maxVal)
val err = et.attr - pred
- val updateY = (q.mapMultiply(err*usr.norm)).subtract((itm.v2).mapMultiply(gamma7))
val updateP = (q.mapMultiply(err)).subtract(p.mapMultiply(gamma7))
val updateQ = (usr.v2.mapMultiply(err)).subtract(q.mapMultiply(gamma7))
+ val updateY = (q.mapMultiply(err*usr.norm)).subtract((itm.v2).mapMultiply(gamma7))
Iterator((et.srcId, new Msg(updateP, updateY, err - gamma6*usr.bias)), (et.dstId, new Msg(updateQ, updateY, err - gamma6*itm.bias)))
}
def reduceF2(g1: Msg, g2: Msg):Msg = {
@@ -140,10 +138,8 @@ object Svdpp {
// calculate error on training set
def mapF3(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Double)] = {
assert(et.srcAttr != null && et.dstAttr != null)
- val usr = et.srcAttr
- val itm = et.dstAttr
- val p = usr.v1
- val q = itm.v1
+ val (usr, item) = (et.srcAttr, et.dstAttr)
+ val (p, q) = (usr.v1, itm.v1)
var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
pred = math.max(pred, minVal)
pred = math.min(pred, maxVal)