From 42fbef3c2a6460bcd389bb86306be3ebc14c998b Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 17 Jan 2013 15:54:59 +0200 Subject: Adding default command line args to SparkALS --- .../src/main/scala/spark/examples/SparkALS.scala | 27 ++++++++++++++-------- 1 file changed, 17 insertions(+), 10 deletions(-) (limited to 'examples') diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala index fb28e2c932..cbd749666d 100644 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/spark/examples/SparkALS.scala @@ -7,6 +7,7 @@ import cern.jet.math._ import cern.colt.matrix._ import cern.colt.matrix.linalg._ import spark._ +import scala.Option object SparkALS { // Parameters set through command line arguments @@ -97,21 +98,27 @@ object SparkALS { def main(args: Array[String]) { var host = "" var slices = 0 - args match { - case Array(m, u, f, iters, slices_, host_) => { - M = m.toInt - U = u.toInt - F = f.toInt - ITERATIONS = iters.toInt - slices = slices_.toInt - host = host_ + + (1 to 6).map(i => { + i match { + case a if a < args.length => Option(args(a)) + case _ => Option(null) + } + }).toArray match { + case Array(host_, m, u, f, iters, slices_) => { + host = host_ getOrElse "local" + M = (m getOrElse "100").toInt + U = (u getOrElse "500").toInt + F = (f getOrElse "10").toInt + ITERATIONS = (iters getOrElse "5").toInt + slices = (slices_ getOrElse "2").toInt } case _ => { - System.err.println("Usage: SparkALS ") + System.err.println("Usage: SparkALS [ ]") System.exit(1) } } - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS); + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) val spark = new SparkContext(host, "SparkALS") val R = generateR() -- cgit v1.2.3 From a512df551f85086a6ec363744542e74749c6b560 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 17 Jan 2013 16:05:27 +0200 Subject: Fixed index error missing first argument --- examples/src/main/scala/spark/examples/SparkALS.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'examples') diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala index cbd749666d..4672812565 100644 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/spark/examples/SparkALS.scala @@ -99,7 +99,7 @@ object SparkALS { var host = "" var slices = 0 - (1 to 6).map(i => { + (0 to 5).map(i => { i match { case a if a < args.length => Option(args(a)) case _ => Option(null) -- cgit v1.2.3 From a5ba7a9f322dce763350864bf89d94e6656d9984 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 17 Jan 2013 16:21:00 +0200 Subject: Use only one update function and pass in transpose of ratings matrix where appropriate --- .../src/main/scala/spark/examples/SparkALS.scala | 32 ++-------------------- 1 file changed, 3 insertions(+), 29 deletions(-) (limited to 'examples') diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala index 4672812565..2766ad1702 100644 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/spark/examples/SparkALS.scala @@ -43,7 +43,7 @@ object SparkALS { return sqrt(sumSqs / (M * U)) } - def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], + def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], R: DoubleMatrix2D) : DoubleMatrix1D = { val U = us.size @@ -69,32 +69,6 @@ object SparkALS { return solved2D.viewColumn(0) } - def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val M = ms.size - val F = ms(0).size - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) - // For each movie that the user rated - for (i <- 0 until M) { - val m = ms(i) - // Add m * m^t to XtX - blas.dger(1, m, m, XtX) - // Add m * rating to Xty - blas.daxpy(R.get(i, j), m, Xty) - } - // Add regularization coefs to diagonal terms - for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * M) - } - // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - return solved2D.viewColumn(0) - } - def main(args: Array[String]) { var host = "" var slices = 0 @@ -134,11 +108,11 @@ object SparkALS { for (iter <- 1 to ITERATIONS) { println("Iteration " + iter + ":") ms = spark.parallelize(0 until M, slices) - .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value)) + .map(i => update(i, msc.value(i), usc.value, Rc.value)) .toArray msc = spark.broadcast(ms) // Re-broadcast ms because it was updated us = spark.parallelize(0 until U, slices) - .map(i => updateUser(i, usc.value(i), msc.value, Rc.value)) + .map(i => update(i, usc.value(i), msc.value, algebra.transpose(Rc.value))) .toArray usc = spark.broadcast(us) // Re-broadcast us because it was updated println("RMSE = " + rmse(R, ms, us)) -- cgit v1.2.3 From 2a8c2a67909c4878ea24ec94f203287e55dd3782 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 10:24:53 -0800 Subject: Minor formatting fixes --- examples/src/main/scala/spark/examples/SparkALS.scala | 4 ++-- python/examples/als.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'examples') diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala index 2766ad1702..5e01885dbb 100644 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/spark/examples/SparkALS.scala @@ -75,8 +75,8 @@ object SparkALS { (0 to 5).map(i => { i match { - case a if a < args.length => Option(args(a)) - case _ => Option(null) + case a if a < args.length => Some(args(a)) + case _ => None } }).toArray match { case Array(host_, m, u, f, iters, slices_) => { diff --git a/python/examples/als.py b/python/examples/als.py index 284cf0d3a2..010f80097f 100755 --- a/python/examples/als.py +++ b/python/examples/als.py @@ -68,4 +68,4 @@ if __name__ == "__main__": error = rmse(R, ms, us) print "Iteration %d:" % i - print "\nRMSE: %5.4f\n" % error \ No newline at end of file + print "\nRMSE: %5.4f\n" % error -- cgit v1.2.3