aboutsummaryrefslogtreecommitdiff
path: root/graph/src
diff options
context:
space:
mode:
authorReynold Xin <rxin@apache.org>2013-12-22 11:44:18 -0800
committerReynold Xin <rxin@apache.org>2013-12-22 11:44:18 -0800
commit44e4205ac579a9a4dfb2f6041d34caea568059ce (patch)
treef75bae17356cf91daea7c1bc294e85c2adb82137 /graph/src
parent4797c227ff7aafcc1e4dcbbaa5281b55361484e6 (diff)
parente64a794a4417f614e1b74180a123f5f913a6db53 (diff)
downloadspark-44e4205ac579a9a4dfb2f6041d34caea568059ce.tar.gz
spark-44e4205ac579a9a4dfb2f6041d34caea568059ce.tar.bz2
spark-44e4205ac579a9a4dfb2f6041d34caea568059ce.zip
Merge pull request #116 from jianpingjwang/master
remove unused variables and fix a bug
Diffstat (limited to 'graph/src')
-rw-r--r--graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala63
1 files changed, 32 insertions, 31 deletions
diff --git a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
index 4ddf0b1fd5..26b999f4cf 100644
--- a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
@@ -1,6 +1,5 @@
package org.apache.spark.graph.algorithms
-import org.apache.spark._
import org.apache.spark.rdd._
import org.apache.spark.graph._
import scala.util.Random
@@ -10,7 +9,7 @@ class VT ( // vertex type
var v1: RealVector, // v1: p for user node, q for item node
var v2: RealVector, // v2: pu + |N(u)|^(-0.5)*sum(y) for user node, y for item node
var bias: Double,
- var norm: Double // only for user node
+ var norm: Double // |N(u)|^(-0.5) for user node
) extends Serializable
class Msg ( // message
@@ -20,7 +19,15 @@ class Msg ( // message
) extends Serializable
object Svdpp {
- // implement SVD++ based on http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf
+ /**
+ * Implement SVD++ based on "Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model",
+ * paper is available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]].
+ * The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^(-0.5)*sum(y)), see the details on page 6.
+ *
+ * @param edges edges for constructing the graph
+ *
+ * @return a graph with vertex attributes containing the trained model
+ */
def run(edges: RDD[Edge[Double]]): Graph[VT, Double] = {
// defalut parameters
@@ -32,7 +39,8 @@ object Svdpp {
val gamma2 = 0.007
val gamma6 = 0.005
val gamma7 = 0.015
-
+
+ // generate default vertex attribute
def defaultF(rank: Int) = {
val v1 = new ArrayRealVector(rank)
val v2 = new ArrayRealVector(rank)
@@ -44,7 +52,7 @@ object Svdpp {
vd
}
- // calculate initial norm and bias
+ // calculate initial bias and norm
def mapF0(et: EdgeTriplet[VT, Double]): Iterator[(Vid, (Long, Double))] = {
assert(et.srcAttr != null && et.dstAttr != null)
Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr)))
@@ -67,10 +75,10 @@ object Svdpp {
// make graph
var g = Graph.fromEdges(edges, defaultF(rank)).cache()
- // calculate initial norm and bias
+ // calculate initial bias and norm
val t0 = g.mapReduceTriplets(mapF0, reduceF0)
- g.outerJoinVertices(t0) {updateF0}
-
+ g.outerJoinVertices(t0) {updateF0}
+
// phase 1
def mapF1(et: EdgeTriplet[VT, Double]): Iterator[(Vid, RealVector)] = {
assert(et.srcAttr != null && et.dstAttr != null)
@@ -89,21 +97,18 @@ object Svdpp {
// phase 2
def mapF2(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Msg)] = {
assert(et.srcAttr != null && et.dstAttr != null)
- val usr = et.srcAttr
- val itm = et.dstAttr
- var p = usr.v1
- var q = itm.v1
- val itmBias = 0.0
- val usrBias = 0.0
+ val (usr, itm) = (et.srcAttr, et.dstAttr)
+ val (p, q) = (usr.v1, itm.v1)
var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
pred = math.max(pred, minVal)
pred = math.min(pred, maxVal)
val err = et.attr - pred
- val y = (q.mapMultiply(err*usr.norm)).subtract((usr.v2).mapMultiply(gamma7))
- val newP = (q.mapMultiply(err)).subtract(p.mapMultiply(gamma7)) // for each connected item q
- val newQ = (usr.v2.mapMultiply(err)).subtract(q.mapMultiply(gamma7))
- Iterator((et.srcId, new Msg(newP, y, err - gamma6*usr.bias)), (et.dstId, new Msg(newQ, y, err - gamma6*itm.bias)))
- }
+ val updateP = (q.mapMultiply(err)).subtract(p.mapMultiply(gamma7))
+ val updateQ = (usr.v2.mapMultiply(err)).subtract(q.mapMultiply(gamma7))
+ val updateY = (q.mapMultiply(err*usr.norm)).subtract((itm.v2).mapMultiply(gamma7))
+ Iterator((et.srcId, new Msg(updateP, updateY, err - gamma6*usr.bias)),
+ (et.dstId, new Msg(updateQ, updateY, err - gamma6*itm.bias)))
+ }
def reduceF2(g1: Msg, g2: Msg):Msg = {
g1.v1 = g1.v1.add(g2.v1)
g1.v2 = g1.v2.add(g2.v2)
@@ -113,7 +118,7 @@ object Svdpp {
def updateF2(vid: Vid, vd: VT, msg: Option[Msg]) = {
if (msg.isDefined) {
vd.v1 = vd.v1.add(msg.get.v1.mapMultiply(gamma2))
- if (vid % 2 == 1) { // item node update y
+ if (vid % 2 == 1) { // item nodes update y
vd.v2 = vd.v2.add(msg.get.v2.mapMultiply(gamma2))
}
vd.bias += msg.get.bias*gamma1
@@ -122,23 +127,19 @@ object Svdpp {
}
for (i <- 0 until maxIters) {
- // phase 1
+ // phase 1, calculate v2 for user nodes
val t1: VertexRDD[RealVector] = g.mapReduceTriplets(mapF1, reduceF1)
- g.outerJoinVertices(t1) {updateF1}
- // phase 2
+ g.outerJoinVertices(t1) {updateF1}
+ // phase 2, update p for user nodes and q, y for item nodes
val t2: VertexRDD[Msg] = g.mapReduceTriplets(mapF2, reduceF2)
g.outerJoinVertices(t2) {updateF2}
}
-
+
// calculate error on training set
def mapF3(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Double)] = {
assert(et.srcAttr != null && et.dstAttr != null)
- val usr = et.srcAttr
- val itm = et.dstAttr
- var p = usr.v1
- var q = itm.v1
- val itmBias = 0.0
- val usrBias = 0.0
+ val (usr, itm) = (et.srcAttr, et.dstAttr)
+ val (p, q) = (usr.v1, itm.v1)
var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
pred = math.max(pred, minVal)
pred = math.min(pred, maxVal)
@@ -146,7 +147,7 @@ object Svdpp {
Iterator((et.dstId, err))
}
def updateF3(vid: Vid, vd: VT, msg: Option[Double]) = {
- if (msg.isDefined && vid % 2 == 1) { // item sum up the errors
+ if (msg.isDefined && vid % 2 == 1) { // item nodes sum up the errors
vd.norm = msg.get
}
vd