path: root/examples/src/main/scala/SparkALS.scala
import java.io.Serializable
import java.util.Random
import scala.math.sqrt
import cern.jet.math._
import cern.colt.matrix._
import cern.colt.matrix.linalg._
import spark._

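/**
 * Alternating Least Squares (ALS) matrix factorization on Spark.
 *
 * An M x U ratings matrix R is approximated by the product of an M x F
 * movie-feature matrix and a U x F user-feature matrix. Each iteration
 * fixes one side and re-solves a regularized least-squares problem for
 * every row of the other, distributing the per-row solves with
 * parallelize() and shipping the shared matrices with broadcast().
 */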
object SparkALS {
  // Parameters set through command line arguments
  var M = 0 // Number of movies
  var U = 0 // Number of users
  var F = 0 // Number of features
  var ITERATIONS = 0

  val LAMBDA = 0.01 // Regularization coefficient

  // Some COLT objects
  val factory2D = DoubleFactory2D.dense
  val factory1D = DoubleFactory1D.dense
  val algebra = Algebra.DEFAULT
  val blas = SeqBlas.seqBlas

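  // Synthesize a dense "ground truth" ratings matrix R = Mh * Uh^T from
  // random factor matrices; the RMSE below is measured against it.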
  def generateR(): DoubleMatrix2D = {
    val mh = factory2D.random(M, F)
    val uh = factory2D.random(U, F)
    algebra.mult(mh, algebra.transpose(uh))
  }

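  // Root-mean-squared error between the predicted matrix ms * us^T and the
  // target R, computed densely over all M * U entries.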
  def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D],
    us: Array[DoubleMatrix1D]): Double =
  {
    val r = factory2D.make(M, U)
    for (i <- 0 until M; j <- 0 until U) {
      r.set(i, j, blas.ddot(ms(i), us(j)))
    }
    // daxpy computes r := (-1) * targetR + r, leaving the error matrix in r
    blas.daxpy(-1, targetR, r)
    val sumSqs = r.aggregate(Functions.plus, Functions.square)
    sqrt(sumSqs / (M * U))
  }

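  // Solve for movie vector i with all user vectors held fixed. The new
  // vector is the solution of the regularized normal equations
  //   (U^T U + LAMBDA * numUsers * I) m_i = U^T R(i, *)
  // where XtX accumulates the Gram matrix via rank-1 updates (dger) and
  // Xty accumulates the right-hand side via scaled additions (daxpy).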
  def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
    R: DoubleMatrix2D) : DoubleMatrix1D =
  {
    val U = us.size
    val F = us(0).size
    val XtX = factory2D.make(F, F)
    val Xty = factory1D.make(F)
    // For each user that rated the movie
    for (j <- 0 until U) {
      val u = us(j)
      // Add u * u^t to XtX
      blas.dger(1, u, u, XtX)
      // Add u * rating to Xty
      blas.daxpy(R.get(i, j), u, Xty)
    }
    // Add regularization coefs to diagonal terms
    for (d <- 0 until F) {
      XtX.set(d, d, XtX.get(d, d) + LAMBDA * U)
    }
    // Solve it with Cholesky
    val ch = new CholeskyDecomposition(XtX)
    val Xty2D = factory2D.make(Xty.toArray, F)
    val solved2D = ch.solve(Xty2D)
    solved2D.viewColumn(0)
  }

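  // Symmetric to updateMovie: solve for user vector j with all movie
  // vectors held fixed, using column j of R as the ratings.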
  def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
    R: DoubleMatrix2D) : DoubleMatrix1D =
  {
    val M = ms.size
    val F = ms(0).size
    val XtX = factory2D.make(F, F)
    val Xty = factory1D.make(F)
    // For each movie that the user rated
    for (i <- 0 until M) {
      val m = ms(i)
      // Add m * m^t to XtX
      blas.dger(1, m, m, XtX)
      // Add m * rating to Xty
      blas.daxpy(R.get(i, j), m, Xty)
    }
    // Add regularization coefs to diagonal terms
    for (d <- 0 until F) {
      XtX.set(d, d, XtX.get(d, d) + LAMBDA * M)
    }
    // Solve it with Cholesky
    val ch = new CholeskyDecomposition(XtX)
    val Xty2D = factory2D.make(Xty.toArray, F)
    val solved2D = ch.solve(Xty2D)
    solved2D.viewColumn(0)
  }

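  // Example invocation (argument values are illustrative only):
  //   SparkALS 100 500 10 5 2 local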
  def main(args: Array[String]): Unit = {
    var host = ""
    var slices = 0
    args match {
      case Array(m, u, f, iters, slices_, host_) => {
        M = m.toInt
        U = u.toInt
        F = f.toInt
        ITERATIONS = iters.toInt
        slices = slices_.toInt
        host = host_
      }
      case _ => {
        System.err.println("Usage: SparkALS <M> <U> <F> <iters> <slices> <host>")
        System.exit(1)
      }
    }
    printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS)
    val spark = new SparkContext(host, "SparkALS")
    
    val R = generateR()

    // Initialize m and u randomly
    var ms = Array.fill(M)(factory1D.random(F))
    var us = Array.fill(U)(factory1D.random(F))

    // Iteratively update movies then users
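    // R never changes, so it is broadcast once; ms and us are re-broadcast
    // after every update because a broadcast value is an immutable snapshot
    // of the data at the time broadcast() was called.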
    val Rc  = spark.broadcast(R)
    var msc = spark.broadcast(ms)
    var usc = spark.broadcast(us)
    for (iter <- 1 to ITERATIONS) {
      println("Iteration " + iter + ":")
      ms = spark.parallelize(0 until M, slices)
                .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value))
                .toArray
      msc = spark.broadcast(ms) // Re-broadcast ms because it was updated
      us = spark.parallelize(0 until U, slices)
                .map(i => updateUser(i, usc.value(i), msc.value, Rc.value))
                .toArray
      usc = spark.broadcast(us) // Re-broadcast us because it was updated
      println("RMSE = " + rmse(R, ms, us))
      println()
    }
  }
}