From e94fe39d0f89ba102b1ebadf8becf99c051eb58d Mon Sep 17 00:00:00 2001 From: Wang Jianping J Date: Wed, 18 Dec 2013 06:39:28 +0800 Subject: Update Svdpp.scala --- .../org/apache/spark/graph/algorithms/Svdpp.scala | 27 +++++++++++----------- 1 file changed, 14 insertions(+), 13 deletions(-) (limited to 'graph') diff --git a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala index a0f025d708..ef266bb551 100644 --- a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala +++ b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala @@ -6,14 +6,14 @@ import org.apache.spark.graph._ import scala.util.Random import org.apache.commons.math.linear._ -class VD ( +class VT ( // vertex type var v1: RealVector, // v1: p for user node, q for item node var v2: RealVector, // v2: pu + |N(u)|^(-0.5)*sum(y) for user node, y for item node var bias: Double, var norm: Double // only for user node ) extends Serializable -class Msg ( +class Msg ( // message var v1: RealVector, var v2: RealVector, var bias: Double @@ -22,7 +22,7 @@ class Msg ( object Svdpp { // implement SVD++ based on http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf - def run(edges: RDD[Edge[Double]]): Graph[VD,Double] = { + def run(edges: RDD[Edge[Double]]): Graph[VT, Double] = { // defalut parameters val rank = 10 val maxIters = 20 @@ -40,19 +40,19 @@ object Svdpp { v1.setEntry(i, Random.nextDouble) v2.setEntry(i, Random.nextDouble) } - var vd = new VD(v1, v2, 0.0, 0.0) + var vd = new VT(v1, v2, 0.0, 0.0) vd } // calculate initial norm and bias - def mapF0(et: EdgeTriplet[VD, Double]): Iterator[(Vid, (Long, Double))] = { + def mapF0(et: EdgeTriplet[VT, Double]): Iterator[(Vid, (Long, Double))] = { assert(et.srcAttr != null && et.dstAttr != null) Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))) } def reduceF0(g1: (Long, Double), g2: (Long, Double)) = { (g1._1 + g2._1, g1._2 + g2._2) } - def updateF0(vid: Vid, vd: VD, msg: Option[(Long, Double)]) = { + def updateF0(vid: Vid, vd: VT, msg: Option[(Long, Double)]) = { if (msg.isDefined) { vd.bias = msg.get._2 / msg.get._1 vd.norm = 1.0 / scala.math.sqrt(msg.get._1) @@ -65,21 +65,21 @@ object Svdpp { val u = rs / rc // global rating mean // make graph - var g = Graph.fromEdges(edges, defaultF(rank)).cache() + var g = Graph.fromEdges(edges, defaultF(rank), RandomVertexCut).cache() // calculate initial norm and bias val t0 = g.mapReduceTriplets(mapF0, reduceF0) g.outerJoinVertices(t0) {updateF0} // phase 1 - def mapF1(et: EdgeTriplet[VD, Double]): Iterator[(Vid, RealVector)] = { + def mapF1(et: EdgeTriplet[VT, Double]): Iterator[(Vid, RealVector)] = { assert(et.srcAttr != null && et.dstAttr != null) Iterator((et.srcId, et.dstAttr.v2)) // sum up y of connected item nodes } def reduceF1(g1: RealVector, g2: RealVector) = { g1.add(g2) } - def updateF1(vid: Vid, vd: VD, msg: Option[RealVector]) = { + def updateF1(vid: Vid, vd: VT, msg: Option[RealVector]) = { if (msg.isDefined) { vd.v2 = vd.v1.add(msg.get.mapMultiply(vd.norm)) // pu + |N(u)|^(-0.5)*sum(y) } @@ -87,7 +87,7 @@ object Svdpp { } // phase 2 - def mapF2(et: EdgeTriplet[VD, Double]): Iterator[(Vid, Msg)] = { + def mapF2(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Msg)] = { assert(et.srcAttr != null && et.dstAttr != null) val usr = et.srcAttr val itm = et.dstAttr @@ -96,6 +96,7 @@ object Svdpp { val itmBias = 0.0 val usrBias = 0.0 var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2) + println(pred) pred = math.max(pred, minVal) pred = math.min(pred, maxVal) val err = et.attr - pred @@ -110,7 +111,7 @@ object Svdpp { g1.bias += g2.bias g1 } - def updateF2(vid: Vid, vd: VD, msg: Option[Msg]) = { + def updateF2(vid: Vid, vd: VT, msg: Option[Msg]) = { if (msg.isDefined) { vd.v1 = vd.v1.add(msg.get.v1.mapMultiply(gamma2)) if (vid % 2 == 1) { // item node update y @@ -131,7 +132,7 @@ object Svdpp { } // calculate error on training set - def mapF3(et: EdgeTriplet[VD, Double]): Iterator[(Vid, Double)] = { + def mapF3(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Double)] = { assert(et.srcAttr != null && et.dstAttr != null) val usr = et.srcAttr val itm = et.dstAttr @@ -145,7 +146,7 @@ object Svdpp { val err = (et.attr - pred)*(et.attr - pred) Iterator((et.dstId, err)) } - def updateF3(vid: Vid, vd: VD, msg: Option[Double]) = { + def updateF3(vid: Vid, vd: VT, msg: Option[Double]) = { if (msg.isDefined && vid % 2 == 1) { // item sum up the errors vd.norm = msg.get } -- cgit v1.2.3