From cc0fd3317757eb745d0df8ba1510dda94cb5d655 Mon Sep 17 00:00:00 2001 From: Jianping J Wang Date: Thu, 23 Jan 2014 19:44:30 +0800 Subject: Replace commons-math with jblas --- .../org/apache/spark/graphx/lib/SVDPlusPlus.scala | 68 ++++++++++++---------- 1 file changed, 36 insertions(+), 32 deletions(-) (limited to 'graphx') diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 79280f836f..ccd7de537b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -18,7 +18,7 @@ package org.apache.spark.graphx.lib import scala.util.Random -import org.apache.commons.math3.linear._ +import org.jblas.DoubleMatrix import org.apache.spark.rdd._ import org.apache.spark.graphx._ @@ -52,15 +52,15 @@ object SVDPlusPlus { * @return a graph with vertex attributes containing the trained model */ def run(edges: RDD[Edge[Double]], conf: Conf) - : (Graph[(RealVector, RealVector, Double, Double), Double], Double) = + : (Graph[(DoubleMatrix, DoubleMatrix, Double, Double), Double], Double) = { // Generate default vertex attribute - def defaultF(rank: Int): (RealVector, RealVector, Double, Double) = { - val v1 = new ArrayRealVector(rank) - val v2 = new ArrayRealVector(rank) + def defaultF(rank: Int): (DoubleMatrix, DoubleMatrix, Double, Double) = { + val v1 = new DoubleMatrix(rank) + val v2 = new DoubleMatrix(rank) for (i <- 0 until rank) { - v1.setEntry(i, Random.nextDouble()) - v2.setEntry(i, Random.nextDouble()) + v1.put(i, Random.nextDouble()) + v2.put(i, Random.nextDouble()) } (v1, v2, 0.0, 0.0) } @@ -76,31 +76,32 @@ object SVDPlusPlus { // Calculate initial bias and norm val t0 = g.mapReduceTriplets( et => Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))), - (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2)) + (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2)) g = g.outerJoinVertices(t0) { - (vid: VertexId, vd: (RealVector, RealVector, Double, Double), msg: Option[(Long, Double)]) => + (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), + msg: Option[(Long, Double)]) => (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) } def mapTrainF(conf: Conf, u: Double) - (et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]) - : Iterator[(VertexId, (RealVector, RealVector, Double))] = { + (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double]) + : Iterator[(VertexId, (DoubleMatrix, DoubleMatrix, Double))] = { val (usr, itm) = (et.srcAttr, et.dstAttr) val (p, q) = (usr._1, itm._1) - var pred = u + usr._3 + itm._3 + q.dotProduct(usr._2) + var pred = u + usr._3 + itm._3 + q.dot(usr._2) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) val err = et.attr - pred - val updateP = q.mapMultiply(err) - .subtract(p.mapMultiply(conf.gamma7)) - .mapMultiply(conf.gamma2) - val updateQ = usr._2.mapMultiply(err) - .subtract(q.mapMultiply(conf.gamma7)) - .mapMultiply(conf.gamma2) - val updateY = q.mapMultiply(err * usr._4) - .subtract(itm._2.mapMultiply(conf.gamma7)) - .mapMultiply(conf.gamma2) + val updateP = q.mul(err) + .subColumnVector(p.mul(conf.gamma7)) + .mul(conf.gamma2) + val updateQ = usr._2.mul(err) + .subColumnVector(q.mul(conf.gamma7)) + .mul(conf.gamma2) + val updateY = q.mul(err * usr._4) + .subColumnVector(itm._2.mul(conf.gamma7)) + .mul(conf.gamma2) Iterator((et.srcId, (updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)), (et.dstId, (updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))) } @@ -110,34 +111,37 @@ object SVDPlusPlus { g.cache() val t1 = g.mapReduceTriplets( et => Iterator((et.srcId, et.dstAttr._2)), - (g1: RealVector, g2: RealVector) => g1.add(g2)) + (g1: DoubleMatrix, g2: DoubleMatrix) => g1.addColumnVector(g2)) g = g.outerJoinVertices(t1) { - (vid: VertexId, vd: (RealVector, RealVector, Double, Double), msg: Option[RealVector]) => - if (msg.isDefined) (vd._1, vd._1.add(msg.get.mapMultiply(vd._4)), vd._3, vd._4) else vd + (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), + msg: Option[DoubleMatrix]) => + if (msg.isDefined) (vd._1, vd._1 + .addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd } // Phase 2, update p for user nodes and q, y for item nodes g.cache() val t2 = g.mapReduceTriplets( mapTrainF(conf, u), - (g1: (RealVector, RealVector, Double), g2: (RealVector, RealVector, Double)) => - (g1._1.add(g2._1), g1._2.add(g2._2), g1._3 + g2._3)) + (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) => + (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3)) g = g.outerJoinVertices(t2) { (vid: VertexId, - vd: (RealVector, RealVector, Double, Double), - msg: Option[(RealVector, RealVector, Double)]) => - (vd._1.add(msg.get._1), vd._2.add(msg.get._2), vd._3 + msg.get._3, vd._4) + vd: (DoubleMatrix, DoubleMatrix, Double, Double), + msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) => + (vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2), + vd._3 + msg.get._3, vd._4) } } // calculate error on training set def mapTestF(conf: Conf, u: Double) - (et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]) + (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double]) : Iterator[(VertexId, Double)] = { val (usr, itm) = (et.srcAttr, et.dstAttr) val (p, q) = (usr._1, itm._1) - var pred = u + usr._3 + itm._3 + q.dotProduct(usr._2) + var pred = u + usr._3 + itm._3 + q.dot(usr._2) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) val err = (et.attr - pred) * (et.attr - pred) @@ -146,7 +150,7 @@ object SVDPlusPlus { g.cache() val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2) g = g.outerJoinVertices(t3) { - (vid: VertexId, vd: (RealVector, RealVector, Double, Double), msg: Option[Double]) => + (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) => if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd } -- cgit v1.2.3