diff options
author | Wang Jianping J <jianping.j.wang@gmail.com> | 2013-12-21 15:51:05 +0800 |
---|---|---|
committer | Wang Jianping J <jianping.j.wang@gmail.com> | 2013-12-21 15:51:05 +0800 |
commit | 49eb0f1351fc73bcb44fffbd30d955a95bd88fbe (patch) | |
tree | 8dc711b15d88b2994b4dc75f3421d497a2a36709 /graph/src | |
parent | f986e4a13662724f0ff8a31a46133616aa2ca1e0 (diff) | |
download | spark-49eb0f1351fc73bcb44fffbd30d955a95bd88fbe.tar.gz spark-49eb0f1351fc73bcb44fffbd30d955a95bd88fbe.tar.bz2 spark-49eb0f1351fc73bcb44fffbd30d955a95bd88fbe.zip |
Update Svdpp.scala
Diffstat (limited to 'graph/src')
-rw-r--r-- | graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala | 14 |
1 files changed, 5 insertions, 9 deletions
diff --git a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala index fe0093c4a4..ac20f15072 100644 --- a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala +++ b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala @@ -98,17 +98,15 @@ object Svdpp { // phase 2 def mapF2(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Msg)] = { assert(et.srcAttr != null && et.dstAttr != null) - val usr = et.srcAttr - val itm = et.dstAttr - val p = usr.v1 - val q = itm.v1 + val (usr, itm) = (et.srcAttr, et.dstAttr) + val (p, q) = (usr.v1, itm.v1) var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2) pred = math.max(pred, minVal) pred = math.min(pred, maxVal) val err = et.attr - pred - val updateY = (q.mapMultiply(err*usr.norm)).subtract((itm.v2).mapMultiply(gamma7)) val updateP = (q.mapMultiply(err)).subtract(p.mapMultiply(gamma7)) val updateQ = (usr.v2.mapMultiply(err)).subtract(q.mapMultiply(gamma7)) + val updateY = (q.mapMultiply(err*usr.norm)).subtract((itm.v2).mapMultiply(gamma7)) Iterator((et.srcId, new Msg(updateP, updateY, err - gamma6*usr.bias)), (et.dstId, new Msg(updateQ, updateY, err - gamma6*itm.bias))) } def reduceF2(g1: Msg, g2: Msg):Msg = { @@ -140,10 +138,8 @@ object Svdpp { // calculate error on training set def mapF3(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Double)] = { assert(et.srcAttr != null && et.dstAttr != null) - val usr = et.srcAttr - val itm = et.dstAttr - val p = usr.v1 - val q = itm.v1 + val (usr, item) = (et.srcAttr, et.dstAttr) + val (p, q) = (usr.v1, itm.v1) var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2) pred = math.max(pred, minVal) pred = math.min(pred, maxVal) |