aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/scala/SparkALS.scala
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/scala/SparkALS.scala')
-rw-r--r--examples/src/main/scala/SparkALS.scala139
1 files changed, 139 insertions, 0 deletions
diff --git a/examples/src/main/scala/SparkALS.scala b/examples/src/main/scala/SparkALS.scala
new file mode 100644
index 0000000000..6fae3c0940
--- /dev/null
+++ b/examples/src/main/scala/SparkALS.scala
@@ -0,0 +1,139 @@
+import java.io.Serializable
+import java.util.Random
+import scala.math.sqrt
+import cern.jet.math._
+import cern.colt.matrix._
+import cern.colt.matrix.linalg._
+import spark._
+
+object SparkALS {
+ // Parameters set through command line arguments
+ var M = 0 // Number of movies
+ var U = 0 // Number of users
+ var F = 0 // Number of features
+ var ITERATIONS = 0
+
+ val LAMBDA = 0.01 // Regularization coefficient
+
+ // Some COLT objects
+ val factory2D = DoubleFactory2D.dense
+ val factory1D = DoubleFactory1D.dense
+ val algebra = Algebra.DEFAULT
+ val blas = SeqBlas.seqBlas
+
+ def generateR(): DoubleMatrix2D = {
+ val mh = factory2D.random(M, F)
+ val uh = factory2D.random(U, F)
+ return algebra.mult(mh, algebra.transpose(uh))
+ }
+
+ def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D],
+ us: Array[DoubleMatrix1D]): Double =
+ {
+ val r = factory2D.make(M, U)
+ for (i <- 0 until M; j <- 0 until U) {
+ r.set(i, j, blas.ddot(ms(i), us(j)))
+ }
+ //println("R: " + r)
+ blas.daxpy(-1, targetR, r)
+ val sumSqs = r.aggregate(Functions.plus, Functions.square)
+ return sqrt(sumSqs / (M * U))
+ }
+
+ def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
+ R: DoubleMatrix2D) : DoubleMatrix1D =
+ {
+ val U = us.size
+ val F = us(0).size
+ val XtX = factory2D.make(F, F)
+ val Xty = factory1D.make(F)
+ // For each user that rated the movie
+ for (j <- 0 until U) {
+ val u = us(j)
+ // Add u * u^t to XtX
+ blas.dger(1, u, u, XtX)
+ // Add u * rating to Xty
+ blas.daxpy(R.get(i, j), u, Xty)
+ }
+ // Add regularization coefs to diagonal terms
+ for (d <- 0 until F) {
+ XtX.set(d, d, XtX.get(d, d) + LAMBDA * U)
+ }
+ // Solve it with Cholesky
+ val ch = new CholeskyDecomposition(XtX)
+ val Xty2D = factory2D.make(Xty.toArray, F)
+ val solved2D = ch.solve(Xty2D)
+ return solved2D.viewColumn(0)
+ }
+
+ def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
+ R: DoubleMatrix2D) : DoubleMatrix1D =
+ {
+ val M = ms.size
+ val F = ms(0).size
+ val XtX = factory2D.make(F, F)
+ val Xty = factory1D.make(F)
+ // For each movie that the user rated
+ for (i <- 0 until M) {
+ val m = ms(i)
+ // Add m * m^t to XtX
+ blas.dger(1, m, m, XtX)
+ // Add m * rating to Xty
+ blas.daxpy(R.get(i, j), m, Xty)
+ }
+ // Add regularization coefs to diagonal terms
+ for (d <- 0 until F) {
+ XtX.set(d, d, XtX.get(d, d) + LAMBDA * M)
+ }
+ // Solve it with Cholesky
+ val ch = new CholeskyDecomposition(XtX)
+ val Xty2D = factory2D.make(Xty.toArray, F)
+ val solved2D = ch.solve(Xty2D)
+ return solved2D.viewColumn(0)
+ }
+
+ def main(args: Array[String]) {
+ var host = ""
+ var slices = 0
+ args match {
+ case Array(m, u, f, iters, slices_, host_) => {
+ M = m.toInt
+ U = u.toInt
+ F = f.toInt
+ ITERATIONS = iters.toInt
+ slices = slices_.toInt
+ host = host_
+ }
+ case _ => {
+ System.err.println("Usage: SparkALS <M> <U> <F> <iters> <slices> <host>")
+ System.exit(1)
+ }
+ }
+ printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS);
+ val spark = new SparkContext(host, "SparkALS")
+
+ val R = generateR()
+
+ // Initialize m and u randomly
+ var ms = Array.fromFunction(_ => factory1D.random(F))(M)
+ var us = Array.fromFunction(_ => factory1D.random(F))(U)
+
+ // Iteratively update movies then users
+ val Rc = spark.broadcast(R)
+ var msc = spark.broadcast(ms)
+ var usc = spark.broadcast(us)
+ for (iter <- 1 to ITERATIONS) {
+ println("Iteration " + iter + ":")
+ ms = spark.parallelize(0 until M, slices)
+ .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value))
+ .toArray
+ msc = spark.broadcast(ms) // Re-broadcast ms because it was updated
+ us = spark.parallelize(0 until U, slices)
+ .map(i => updateUser(i, usc.value(i), msc.value, Rc.value))
+ .toArray
+ usc = spark.broadcast(us) // Re-broadcast us because it was updated
+ println("RMSE = " + rmse(R, ms, us))
+ println()
+ }
+ }
+}