From a5ba7a9f322dce763350864bf89d94e6656d9984 Mon Sep 17 00:00:00 2001
From: Nick Pentreath
Date: Thu, 17 Jan 2013 16:21:00 +0200
Subject: Use only one update function and pass in transpose of ratings matrix where appropriate

---
 .../src/main/scala/spark/examples/SparkALS.scala | 32 ++--------------------
 1 file changed, 3 insertions(+), 29 deletions(-)

(limited to 'examples')

diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala
index 4672812565..2766ad1702 100644
--- a/examples/src/main/scala/spark/examples/SparkALS.scala
+++ b/examples/src/main/scala/spark/examples/SparkALS.scala
@@ -43,7 +43,7 @@ object SparkALS {
     return sqrt(sumSqs / (M * U))
   }

-  def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
+  def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
     R: DoubleMatrix2D) : DoubleMatrix1D =
   {
     val U = us.size
@@ -69,32 +69,6 @@ object SparkALS {
     return solved2D.viewColumn(0)
   }

-  def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
-    R: DoubleMatrix2D) : DoubleMatrix1D =
-  {
-    val M = ms.size
-    val F = ms(0).size
-    val XtX = factory2D.make(F, F)
-    val Xty = factory1D.make(F)
-    // For each movie that the user rated
-    for (i <- 0 until M) {
-      val m = ms(i)
-      // Add m * m^t to XtX
-      blas.dger(1, m, m, XtX)
-      // Add m * rating to Xty
-      blas.daxpy(R.get(i, j), m, Xty)
-    }
-    // Add regularization coefs to diagonal terms
-    for (d <- 0 until F) {
-      XtX.set(d, d, XtX.get(d, d) + LAMBDA * M)
-    }
-    // Solve it with Cholesky
-    val ch = new CholeskyDecomposition(XtX)
-    val Xty2D = factory2D.make(Xty.toArray, F)
-    val solved2D = ch.solve(Xty2D)
-    return solved2D.viewColumn(0)
-  }
-
   def main(args: Array[String]) {
     var host = ""
     var slices = 0
@@ -134,11 +108,11 @@ object SparkALS {
     for (iter <- 1 to ITERATIONS) {
       println("Iteration " + iter + ":")
       ms = spark.parallelize(0 until M, slices)
-                .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value))
+                .map(i => update(i, msc.value(i), usc.value, Rc.value))
                 .toArray
      msc = spark.broadcast(ms) // Re-broadcast ms because it was updated
      us = spark.parallelize(0 until U, slices)
-                .map(i => updateUser(i, usc.value(i), msc.value, Rc.value))
+                .map(i => update(i, usc.value(i), msc.value, algebra.transpose(Rc.value)))
                 .toArray
      usc = spark.broadcast(us) // Re-broadcast us because it was updated
      println("RMSE = " + rmse(R, ms, us))
--
cgit v1.2.3
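
Why a single update function is enough: the removed updateUser differed from updateMovie only in reading ratings as R.get(i, j) with the roles of row (movie) and column (user) swapped, so handing the user pass algebra.transpose(Rc.value) lets the one remaining body serve both cases. The sketch below is not part of the patch; it is a minimal standalone illustration of that indexing symmetry, assuming only the Colt library (cern.colt) that the example already uses. The object name TransposeSymmetrySketch and the toy dimensions are made up for illustration.

import cern.colt.matrix.DoubleFactory2D
import cern.colt.matrix.linalg.Algebra

object TransposeSymmetrySketch {
  def main(args: Array[String]) {
    val factory2D = DoubleFactory2D.dense
    val algebra = Algebra.DEFAULT

    val M = 3                      // movies = rows of R (toy size, hypothetical)
    val U = 2                      // users  = columns of R
    val R = factory2D.make(M, U)   // zero-initialized M x U ratings matrix
    R.set(1, 0, 4.0)               // user 0 rates movie 1 with 4.0

    // Transposing swaps rows and columns, so the single update(i, ..., R) body,
    // which always reads ratings as R.get(i, j), sees (movie i, user j) when
    // given R and (user i, movie j) when given the transpose.
    val Rt = algebra.transpose(R)
    println(R.get(1, 0) == Rt.get(0, 1)) // prints true
  }
}

Note that Algebra.transpose returns a view rather than a copy, so the user pass in the patched loop does not materialize a second U x M matrix.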