aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorNick Pentreath <nick.pentreath@gmail.com>2013-01-17 16:21:00 +0200
committerNick Pentreath <nick.pentreath@gmail.com>2013-01-17 16:21:00 +0200
commita5ba7a9f322dce763350864bf89d94e6656d9984 (patch)
tree8b909c78b9e4d319e4ee055da1fd102bddab2b2a /examples
parenta512df551f85086a6ec363744542e74749c6b560 (diff)
downloadspark-a5ba7a9f322dce763350864bf89d94e6656d9984.tar.gz
spark-a5ba7a9f322dce763350864bf89d94e6656d9984.tar.bz2
spark-a5ba7a9f322dce763350864bf89d94e6656d9984.zip
Use only one update function and pass in transpose of ratings matrix where appropriate
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/scala/spark/examples/SparkALS.scala32
1 files changed, 3 insertions, 29 deletions
diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala
index 4672812565..2766ad1702 100644
--- a/examples/src/main/scala/spark/examples/SparkALS.scala
+++ b/examples/src/main/scala/spark/examples/SparkALS.scala
@@ -43,7 +43,7 @@ object SparkALS {
return sqrt(sumSqs / (M * U))
}
- def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
+ def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
R: DoubleMatrix2D) : DoubleMatrix1D =
{
val U = us.size
@@ -69,32 +69,6 @@ object SparkALS {
return solved2D.viewColumn(0)
}
- def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
- R: DoubleMatrix2D) : DoubleMatrix1D =
- {
- val M = ms.size
- val F = ms(0).size
- val XtX = factory2D.make(F, F)
- val Xty = factory1D.make(F)
- // For each movie that the user rated
- for (i <- 0 until M) {
- val m = ms(i)
- // Add m * m^t to XtX
- blas.dger(1, m, m, XtX)
- // Add m * rating to Xty
- blas.daxpy(R.get(i, j), m, Xty)
- }
- // Add regularization coefs to diagonal terms
- for (d <- 0 until F) {
- XtX.set(d, d, XtX.get(d, d) + LAMBDA * M)
- }
- // Solve it with Cholesky
- val ch = new CholeskyDecomposition(XtX)
- val Xty2D = factory2D.make(Xty.toArray, F)
- val solved2D = ch.solve(Xty2D)
- return solved2D.viewColumn(0)
- }
-
def main(args: Array[String]) {
var host = ""
var slices = 0
@@ -134,11 +108,11 @@ object SparkALS {
for (iter <- 1 to ITERATIONS) {
println("Iteration " + iter + ":")
ms = spark.parallelize(0 until M, slices)
- .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value))
+ .map(i => update(i, msc.value(i), usc.value, Rc.value))
.toArray
msc = spark.broadcast(ms) // Re-broadcast ms because it was updated
us = spark.parallelize(0 until U, slices)
- .map(i => updateUser(i, usc.value(i), msc.value, Rc.value))
+ .map(i => update(i, usc.value(i), msc.value, algebra.transpose(Rc.value)))
.toArray
usc = spark.broadcast(us) // Re-broadcast us because it was updated
println("RMSE = " + rmse(R, ms, us))