aboutsummaryrefslogtreecommitdiff
path: root/graph/src
diff options
context:
space:
mode:
authorJianping J Wang <jianping.j.wang@gmail.com>2013-12-30 23:41:15 +0800
committerJianping J Wang <jianping.j.wang@gmail.com>2013-12-30 23:41:15 +0800
commit29fe6bdaa29193c9dbf3a8fbd05094f3d812d4e5 (patch)
tree636689f31a3e07a55719238378722b8971b25305 /graph/src
parent44e4205ac579a9a4dfb2f6041d34caea568059ce (diff)
downloadspark-29fe6bdaa29193c9dbf3a8fbd05094f3d812d4e5.tar.gz
spark-29fe6bdaa29193c9dbf3a8fbd05094f3d812d4e5.tar.bz2
spark-29fe6bdaa29193c9dbf3a8fbd05094f3d812d4e5.zip
refactor and bug fix
Diffstat (limited to 'graph/src')
-rw-r--r--graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala155
1 files changed, 64 insertions, 91 deletions
diff --git a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
index 26b999f4cf..cbbe240c90 100644
--- a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
@@ -5,18 +5,27 @@ import org.apache.spark.graph._
import scala.util.Random
import org.apache.commons.math.linear._
-class VT ( // vertex type
+class VT( // vertex type
var v1: RealVector, // v1: p for user node, q for item node
var v2: RealVector, // v2: pu + |N(u)|^(-0.5)*sum(y) for user node, y for item node
var bias: Double,
var norm: Double // |N(u)|^(-0.5) for user node
-) extends Serializable
+ ) extends Serializable
-class Msg ( // message
+class Msg( // message
var v1: RealVector,
var v2: RealVector,
- var bias: Double
-) extends Serializable
+ var bias: Double) extends Serializable
+
+class SvdppConf( // Svdpp parameters: serializable bundle of settings read by Svdpp.run
+ var rank: Int, // passed to defaultF when building default vertex attributes; presumably the latent vector size -- TODO confirm
+ var maxIters: Int, // number of passes of the training loop in Svdpp.run
+ var minVal: Double, // lower clamp applied to each predicted rating
+ var maxVal: Double, // upper clamp applied to each predicted rating
+ var gamma1: Double, // multiplier applied to the bias update: (err - gamma6 * bias) * gamma1
+ var gamma2: Double, // multiplier applied to the p, q and y vector updates
+ var gamma6: Double, // coefficient on the current bias subtracted from err in the bias update
+ var gamma7: Double) extends Serializable // gamma7: coefficient on the current p/q/y vector subtracted in the vector updates
object Svdpp {
/**
@@ -24,21 +33,14 @@ object Svdpp {
* paper is available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]].
* The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^(-0.5)*sum(y)), see the details on page 6.
*
- * @param edges edges for constructing the graph
+ * @param edges edges for constructing the graph
+ *
+ * @param conf Svdpp parameters
*
* @return a graph with vertex attributes containing the trained model
*/
- def run(edges: RDD[Edge[Double]]): Graph[VT, Double] = {
- // defalut parameters
- val rank = 10
- val maxIters = 20
- val minVal = 0.0
- val maxVal = 5.0
- val gamma1 = 0.007
- val gamma2 = 0.007
- val gamma6 = 0.005
- val gamma7 = 0.015
+ def run(edges: RDD[Edge[Double]], conf: SvdppConf): Graph[VT, Double] = {
// generate default vertex attribute
def defaultF(rank: Int) = {
@@ -52,108 +54,79 @@ object Svdpp {
vd
}
- // calculate initial bias and norm
- def mapF0(et: EdgeTriplet[VT, Double]): Iterator[(Vid, (Long, Double))] = {
- assert(et.srcAttr != null && et.dstAttr != null)
- Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr)))
- }
- def reduceF0(g1: (Long, Double), g2: (Long, Double)) = {
- (g1._1 + g2._1, g1._2 + g2._2)
- }
- def updateF0(vid: Vid, vd: VT, msg: Option[(Long, Double)]) = {
- if (msg.isDefined) {
- vd.bias = msg.get._2 / msg.get._1
- vd.norm = 1.0 / scala.math.sqrt(msg.get._1)
- }
- vd
- }
-
// calculate global rating mean
val (rs, rc) = edges.map(e => (e.attr, 1L)).reduce((a, b) => (a._1 + b._1, a._2 + b._2))
val u = rs / rc // global rating mean
- // make graph
- var g = Graph.fromEdges(edges, defaultF(rank)).cache()
+ // construct graph
+ var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()
// calculate initial bias and norm
- val t0 = g.mapReduceTriplets(mapF0, reduceF0)
- g.outerJoinVertices(t0) {updateF0}
-
- // phase 1
- def mapF1(et: EdgeTriplet[VT, Double]): Iterator[(Vid, RealVector)] = {
- assert(et.srcAttr != null && et.dstAttr != null)
- Iterator((et.srcId, et.dstAttr.v2)) // sum up y of connected item nodes
- }
- def reduceF1(g1: RealVector, g2: RealVector) = {
- g1.add(g2)
- }
- def updateF1(vid: Vid, vd: VT, msg: Option[RealVector]) = {
- if (msg.isDefined) {
- vd.v2 = vd.v1.add(msg.get.mapMultiply(vd.norm)) // pu + |N(u)|^(-0.5)*sum(y)
- }
- vd
+ var t0: VertexRDD[(Long, Double)] = g.mapReduceTriplets(et =>
+ Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))),
+ (g1: (Long, Double), g2: (Long, Double)) =>
+ (g1._1 + g2._1, g1._2 + g2._2))
+ g = g.outerJoinVertices(t0) {
+ (vid: Vid, vd: VT, msg: Option[(Long, Double)]) =>
+ vd.bias = msg.get._2 / msg.get._1; vd.norm = 1.0 / scala.math.sqrt(msg.get._1)
+ vd
}
- // phase 2
- def mapF2(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Msg)] = {
+ def mapTrainF(conf: SvdppConf, u: Double)(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Msg)] = {
assert(et.srcAttr != null && et.dstAttr != null)
val (usr, itm) = (et.srcAttr, et.dstAttr)
val (p, q) = (usr.v1, itm.v1)
var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
- pred = math.max(pred, minVal)
- pred = math.min(pred, maxVal)
+ pred = math.max(pred, conf.minVal)
+ pred = math.min(pred, conf.maxVal)
val err = et.attr - pred
- val updateP = (q.mapMultiply(err)).subtract(p.mapMultiply(gamma7))
- val updateQ = (usr.v2.mapMultiply(err)).subtract(q.mapMultiply(gamma7))
- val updateY = (q.mapMultiply(err*usr.norm)).subtract((itm.v2).mapMultiply(gamma7))
- Iterator((et.srcId, new Msg(updateP, updateY, err - gamma6*usr.bias)),
- (et.dstId, new Msg(updateQ, updateY, err - gamma6*itm.bias)))
- }
- def reduceF2(g1: Msg, g2: Msg):Msg = {
- g1.v1 = g1.v1.add(g2.v1)
- g1.v2 = g1.v2.add(g2.v2)
- g1.bias += g2.bias
- g1
- }
- def updateF2(vid: Vid, vd: VT, msg: Option[Msg]) = {
- if (msg.isDefined) {
- vd.v1 = vd.v1.add(msg.get.v1.mapMultiply(gamma2))
- if (vid % 2 == 1) { // item nodes update y
- vd.v2 = vd.v2.add(msg.get.v2.mapMultiply(gamma2))
- }
- vd.bias += msg.get.bias*gamma1
- }
- vd
+ val updateP = ((q.mapMultiply(err)).subtract(p.mapMultiply(conf.gamma7))).mapMultiply(conf.gamma2)
+ val updateQ = ((usr.v2.mapMultiply(err)).subtract(q.mapMultiply(conf.gamma7))).mapMultiply(conf.gamma2)
+ val updateY = ((q.mapMultiply(err * usr.norm)).subtract((itm.v2).mapMultiply(conf.gamma7))).mapMultiply(conf.gamma2)
+ Iterator((et.srcId, new Msg(updateP, updateY, (err - conf.gamma6 * usr.bias) * conf.gamma1)),
+ (et.dstId, new Msg(updateQ, updateY, (err - conf.gamma6 * itm.bias) * conf.gamma1)))
}
- for (i <- 0 until maxIters) {
+ for (i <- 0 until conf.maxIters) {
// phase 1, calculate v2 for user nodes
- val t1: VertexRDD[RealVector] = g.mapReduceTriplets(mapF1, reduceF1)
- g.outerJoinVertices(t1) {updateF1}
+ var t1 = g.mapReduceTriplets(et =>
+ Iterator((et.srcId, et.dstAttr.v2)),
+ (g1: RealVector, g2: RealVector) => g1.add(g2))
+ g = g.outerJoinVertices(t1) { (vid: Vid, vd: VT, msg: Option[RealVector]) =>
+ if (msg.isDefined) vd.v2 = vd.v1.add(msg.get.mapMultiply(vd.norm))
+ vd
+ }
// phase 2, update p for user nodes and q, y for item nodes
- val t2: VertexRDD[Msg] = g.mapReduceTriplets(mapF2, reduceF2)
- g.outerJoinVertices(t2) {updateF2}
+ val t2: VertexRDD[Msg] = g.mapReduceTriplets(mapTrainF(conf, u), (g1: Msg, g2: Msg) => {
+ g1.v1 = g1.v1.add(g2.v1)
+ g1.v2 = g1.v2.add(g2.v2)
+ g1.bias += g2.bias
+ g1
+ })
+ g = g.outerJoinVertices(t2) { (vid: Vid, vd: VT, msg: Option[Msg]) =>
+ vd.v1 = vd.v1.add(msg.get.v1)
+ if (vid % 2 == 1) vd.v2 = vd.v2.add(msg.get.v2)
+ vd.bias += msg.get.bias
+ vd
+ }
}
// calculate error on training set
- def mapF3(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Double)] = {
+ def mapTestF(conf: SvdppConf, u: Double)(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Double)] = {
assert(et.srcAttr != null && et.dstAttr != null)
val (usr, itm) = (et.srcAttr, et.dstAttr)
val (p, q) = (usr.v1, itm.v1)
var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
- pred = math.max(pred, minVal)
- pred = math.min(pred, maxVal)
- val err = (et.attr - pred)*(et.attr - pred)
+ pred = math.max(pred, conf.minVal)
+ pred = math.min(pred, conf.maxVal)
+ val err = (et.attr - pred) * (et.attr - pred)
Iterator((et.dstId, err))
}
- def updateF3(vid: Vid, vd: VT, msg: Option[Double]) = {
- if (msg.isDefined && vid % 2 == 1) { // item nodes sum up the errors
- vd.norm = msg.get
- }
+ val t3: VertexRDD[Double] = g.mapReduceTriplets(mapTestF(conf, u), _ + _)
+ g.outerJoinVertices(t3) { (vid: Vid, vd: VT, msg: Option[Double]) =>
+ if (msg.isDefined && vid % 2 == 1) vd.norm = msg.get // item nodes sum up the errors
vd
}
- val t3: VertexRDD[Double] = g.mapReduceTriplets(mapF3, _ + _)
- g.outerJoinVertices(t3) {updateF3}
- g
+ g
}
}