author      Joey <joseph.e.gonzalez@gmail.com>  2013-12-18 12:52:36 -0800
committer   Joey <joseph.e.gonzalez@gmail.com>  2013-12-18 12:52:36 -0800
commit      3fd2e09ffb8718f347f9fa1fb057d8738ce73c80 (patch)
tree        8bf233db61117e76af32e99faa8122a2ca364085 /graph/src
parent      1b5eacbb28b1f0f56e1bdf9282e064af4198ba18 (diff)
parent      06581b6a96713d61a61c4ad8eba34fa1e7ecff48 (diff)
download    spark-3fd2e09ffb8718f347f9fa1fb057d8738ce73c80.tar.gz
            spark-3fd2e09ffb8718f347f9fa1fb057d8738ce73c80.tar.bz2
            spark-3fd2e09ffb8718f347f9fa1fb057d8738ce73c80.zip
Merge pull request #104 from jianpingjwang/master
SVD++ demo
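
This patch adds an SVD++ implementation on top of the graph API, following the model in the KDD'08 paper linked from Svdpp.scala. For reference, the rating prediction the code computes is

    \hat{r}_{ui} = \mu + b_u + b_i + q_i^\top \left( p_u + |N(u)|^{-1/2} \sum_{j \in N(u)} y_j \right)

where \mu is the global rating mean, b_u and b_i are the user and item biases, p_u and q_i are the latent factor vectors, and the y_j are implicit-feedback factors over the items N(u) rated by user u. In the code below, v1 holds p_u (user vertex) or q_i (item vertex), v2 holds the parenthesized sum (user) or y (item), and \mu, the biases, and the norm |N(u)|^{-1/2} are computed in an initialization pass before the gradient-descent iterations.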
Diffstat (limited to 'graph/src')
-rw-r--r--  graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala   158
-rw-r--r--  graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala       15
2 files changed, 173 insertions, 0 deletions
diff --git a/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
new file mode 100644
index 0000000000..4ddf0b1fd5
--- /dev/null
+++ b/graph/src/main/scala/org/apache/spark/graph/algorithms/Svdpp.scala
@@ -0,0 +1,158 @@
+package org.apache.spark.graph.algorithms
+
+import org.apache.spark._
+import org.apache.spark.rdd._
+import org.apache.spark.graph._
+import scala.util.Random
+import org.apache.commons.math.linear._
+
+class VT ( // vertex type
+ var v1: RealVector, // v1: p for user node, q for item node
+ var v2: RealVector, // v2: pu + |N(u)|^(-0.5)*sum(y) for user node, y for item node
+ var bias: Double,
+ var norm: Double // only for user node
+) extends Serializable
+
+class Msg ( // message
+ var v1: RealVector,
+ var v2: RealVector,
+ var bias: Double
+) extends Serializable
+
+object Svdpp {
+ // implement SVD++ based on http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf
+
+ def run(edges: RDD[Edge[Double]]): Graph[VT, Double] = {
+    // default parameters
+ val rank = 10
+ val maxIters = 20
+ val minVal = 0.0
+ val maxVal = 5.0
+ val gamma1 = 0.007
+ val gamma2 = 0.007
+ val gamma6 = 0.005
+ val gamma7 = 0.015
+
+ def defaultF(rank: Int) = {
+ val v1 = new ArrayRealVector(rank)
+ val v2 = new ArrayRealVector(rank)
+ for (i <- 0 until rank) {
+ v1.setEntry(i, Random.nextDouble)
+ v2.setEntry(i, Random.nextDouble)
+ }
+      val vd = new VT(v1, v2, 0.0, 0.0)
+ vd
+ }
+
+ // calculate initial norm and bias
+ def mapF0(et: EdgeTriplet[VT, Double]): Iterator[(Vid, (Long, Double))] = {
+ assert(et.srcAttr != null && et.dstAttr != null)
+ Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr)))
+ }
+ def reduceF0(g1: (Long, Double), g2: (Long, Double)) = {
+ (g1._1 + g2._1, g1._2 + g2._2)
+ }
+ def updateF0(vid: Vid, vd: VT, msg: Option[(Long, Double)]) = {
+ if (msg.isDefined) {
+ vd.bias = msg.get._2 / msg.get._1
+ vd.norm = 1.0 / scala.math.sqrt(msg.get._1)
+ }
+ vd
+ }
+
+ // calculate global rating mean
+ val (rs, rc) = edges.map(e => (e.attr, 1L)).reduce((a, b) => (a._1 + b._1, a._2 + b._2))
+ val u = rs / rc // global rating mean
+
+ // make graph
+ var g = Graph.fromEdges(edges, defaultF(rank)).cache()
+
+ // calculate initial norm and bias
+ val t0 = g.mapReduceTriplets(mapF0, reduceF0)
+    g = g.outerJoinVertices(t0) {updateF0} // keep the joined graph; Graph is immutable, so the result must be assigned back
+
+ // phase 1
+ def mapF1(et: EdgeTriplet[VT, Double]): Iterator[(Vid, RealVector)] = {
+ assert(et.srcAttr != null && et.dstAttr != null)
+ Iterator((et.srcId, et.dstAttr.v2)) // sum up y of connected item nodes
+ }
+ def reduceF1(g1: RealVector, g2: RealVector) = {
+ g1.add(g2)
+ }
+ def updateF1(vid: Vid, vd: VT, msg: Option[RealVector]) = {
+ if (msg.isDefined) {
+ vd.v2 = vd.v1.add(msg.get.mapMultiply(vd.norm)) // pu + |N(u)|^(-0.5)*sum(y)
+ }
+ vd
+ }
+
+ // phase 2
+ def mapF2(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Msg)] = {
+ assert(et.srcAttr != null && et.dstAttr != null)
+ val usr = et.srcAttr
+ val itm = et.dstAttr
+      val p = usr.v1
+      val q = itm.v1
+ val itmBias = 0.0
+ val usrBias = 0.0
+ var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
+ pred = math.max(pred, minVal)
+ pred = math.min(pred, maxVal)
+ val err = et.attr - pred
+ val y = (q.mapMultiply(err*usr.norm)).subtract((usr.v2).mapMultiply(gamma7))
+ val newP = (q.mapMultiply(err)).subtract(p.mapMultiply(gamma7)) // for each connected item q
+ val newQ = (usr.v2.mapMultiply(err)).subtract(q.mapMultiply(gamma7))
+ Iterator((et.srcId, new Msg(newP, y, err - gamma6*usr.bias)), (et.dstId, new Msg(newQ, y, err - gamma6*itm.bias)))
+ }
+    def reduceF2(g1: Msg, g2: Msg): Msg = {
+ g1.v1 = g1.v1.add(g2.v1)
+ g1.v2 = g1.v2.add(g2.v2)
+ g1.bias += g2.bias
+ g1
+ }
+ def updateF2(vid: Vid, vd: VT, msg: Option[Msg]) = {
+ if (msg.isDefined) {
+ vd.v1 = vd.v1.add(msg.get.v1.mapMultiply(gamma2))
+        if (vid % 2 == 1) { // item nodes (odd vertex ids) also update y
+ vd.v2 = vd.v2.add(msg.get.v2.mapMultiply(gamma2))
+ }
+ vd.bias += msg.get.bias*gamma1
+ }
+ vd
+ }
+
+ for (i <- 0 until maxIters) {
+ // phase 1
+ val t1: VertexRDD[RealVector] = g.mapReduceTriplets(mapF1, reduceF1)
+      g = g.outerJoinVertices(t1) {updateF1}
+ // phase 2
+ val t2: VertexRDD[Msg] = g.mapReduceTriplets(mapF2, reduceF2)
+      g = g.outerJoinVertices(t2) {updateF2}
+ }
+
+ // calculate error on training set
+ def mapF3(et: EdgeTriplet[VT, Double]): Iterator[(Vid, Double)] = {
+ assert(et.srcAttr != null && et.dstAttr != null)
+ val usr = et.srcAttr
+ val itm = et.dstAttr
+      val p = usr.v1
+      val q = itm.v1
+ val itmBias = 0.0
+ val usrBias = 0.0
+ var pred = u + usr.bias + itm.bias + q.dotProduct(usr.v2)
+ pred = math.max(pred, minVal)
+ pred = math.min(pred, maxVal)
+ val err = (et.attr - pred)*(et.attr - pred)
+ Iterator((et.dstId, err))
+ }
+ def updateF3(vid: Vid, vd: VT, msg: Option[Double]) = {
+      if (msg.isDefined && vid % 2 == 1) { // item nodes store the summed squared error (reusing norm)
+ vd.norm = msg.get
+ }
+ vd
+ }
+ val t3: VertexRDD[Double] = g.mapReduceTriplets(mapF3, _ + _)
+    g = g.outerJoinVertices(t3) {updateF3}
+ g
+ }
+}
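
The accompanying test below runs the new algorithm on the small ALS sample data set and checks the mean squared training error. Note the vertex-id convention it relies on: user ids are mapped to even vertex ids (id * 2) and item ids to odd ones (id * 2 + 1), which is how the `vid % 2 == 1` checks in Svdpp.scala tell item vertices apart from user vertices. A minimal standalone sketch of the same call, assuming an existing SparkContext `sc` and a hypothetical "user,item,rating" CSV file (the path and value names are illustrative only):

    import org.apache.spark.graph._
    import org.apache.spark.graph.algorithms.Svdpp

    // hypothetical input path; each line is "userId,itemId,rating"
    val edges = sc.textFile("data/ratings.csv").map { line =>
      val fields = line.split(",")
      // users -> even vertex ids, items -> odd vertex ids
      Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
    }
    val model = Svdpp.run(edges) // Graph[VT, Double] holding the learned factors and biases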
diff --git a/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala b/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
index b413b4587e..05ebe2b84d 100644
--- a/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
@@ -257,4 +257,19 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
verts.collect.foreach { case (vid, count) => assert(count === 1) }
}
}
+
+ test("Test SVD++ with mean square error on training set") {
+ withSpark(new SparkContext("local", "test")) { sc =>
+ val SvdppErr = 0.01
+ val edges = sc.textFile("mllib/data/als/test.data").map { line =>
+ val fields = line.split(",")
+ Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
+ }
+ val graph = Svdpp.run(edges)
+ val err = graph.vertices.collect.map{ case (vid, vd) =>
+ if (vid % 2 == 1) { vd.norm } else { 0.0 }
+ }.reduce(_ + _) / graph.triplets.collect.size
+ assert(err < SvdppErr)
+ }
+ }
} // end of AnalyticsSuite
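
Once Svdpp.run has returned, the learned parameters can be read straight off the vertices of the resulting graph. A minimal sketch, assuming the `model` value from the sketch above and the same even/odd vertex-id convention (`itemFactors` is just an illustrative name):

    // item vertices carry odd ids; for an item vertex, v1 is q_i and bias is b_i
    val itemFactors = model.vertices
      .filter { case (vid, _) => vid % 2 == 1 }
      .map { case (vid, vd) => (vid / 2, (vd.v1, vd.bias)) } // vid / 2 recovers the original item id

After training, the norm field of each item vertex holds that item's accumulated squared training error, which is what the test above sums to compute the overall mean squared error.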