aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-03-13 00:43:19 -0700
committerReynold Xin <rxin@apache.org>2014-03-13 00:43:19 -0700
commite4e8d8f395aea48f0cae00d7c381a863c48a2837 (patch)
tree283f03c2f4a7be86eb7d950dd0ee51c375b0f737
parent4ea23db0efff2f39ac5b8f0bd1d9a6ffa3eceb0d (diff)
downloadspark-e4e8d8f395aea48f0cae00d7c381a863c48a2837.tar.gz
spark-e4e8d8f395aea48f0cae00d7c381a863c48a2837.tar.bz2
spark-e4e8d8f395aea48f0cae00d7c381a863c48a2837.zip
[SPARK-1237, 1238] Improve the computation of YtY for implicit ALS
Computing YtY can be implemented using BLAS's DSPR operations instead of generating y_i y_i^T and then combining them. The latter generates many k-by-k matrices. On the movielens data, this change improves the performance by 10-20%. The algorithm remains the same, verified by computing RMSE on the movielens data. To compare the results, I also added an option to set a random seed in ALS. JIRA: 1. https://spark-project.atlassian.net/browse/SPARK-1237 2. https://spark-project.atlassian.net/browse/SPARK-1238 Author: Xiangrui Meng <meng@databricks.com> Closes #131 from mengxr/als and squashes the following commits: ed00432 [Xiangrui Meng] minor changes d984623 [Xiangrui Meng] minor changes 2fc1641 [Xiangrui Meng] remove commented code 4c7cde2 [Xiangrui Meng] allow specifying a random seed in ALS 200bef0 [Xiangrui Meng] optimize computeYtY and updateBlock
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala174
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala15
2 files changed, 134 insertions, 55 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 8958040e36..777d0db2d6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -89,10 +89,15 @@ case class Rating(val user: Int, val product: Int, val rating: Double)
* indicated user
* preferences rather than explicit ratings given to items.
*/
-class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double,
- var implicitPrefs: Boolean, var alpha: Double)
- extends Serializable with Logging
-{
+class ALS private (
+ var numBlocks: Int,
+ var rank: Int,
+ var iterations: Int,
+ var lambda: Double,
+ var implicitPrefs: Boolean,
+ var alpha: Double,
+ var seed: Long = System.nanoTime()
+ ) extends Serializable with Logging {
def this() = this(-1, 10, 10, 0.01, false, 1.0)
/**
@@ -132,6 +137,12 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
this
}
+ /** Sets a random seed to have deterministic results. */
+ def setSeed(seed: Long): ALS = {
+ this.seed = seed
+ this
+ }
+
/**
* Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
@@ -155,7 +166,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
// Initialize user and product factors randomly, but use a deterministic seed for each
// partition so that fault recovery works
- val seedGen = new Random()
+ val seedGen = new Random(seed)
val seed1 = seedGen.nextInt()
val seed2 = seedGen.nextInt()
// Hash an integer to propagate random bits at all positions, similar to java.util.HashTable
@@ -210,22 +221,47 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
*/
def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = {
if (implicitPrefs) {
- Option(
- factors.flatMapValues { case factorArray =>
- factorArray.view.map { vector =>
- val x = new DoubleMatrix(vector)
- x.mmul(x.transpose())
- }
- }.reduceByKeyLocally((a, b) => a.addi(b))
- .values
- .reduce((a, b) => a.addi(b))
- )
+ val n = rank * (rank + 1) / 2
+ val LYtY = factors.values.aggregate(new DoubleMatrix(n))( seqOp = (L, Y) => {
+ Y.foreach(y => dspr(1.0, new DoubleMatrix(y), L))
+ L
+ }, combOp = (L1, L2) => {
+ L1.addi(L2)
+ })
+ val YtY = new DoubleMatrix(rank, rank)
+ fillFullMatrix(LYtY, YtY)
+ Option(YtY)
} else {
None
}
}
/**
+ * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR.
+ *
+ * @param L the lower triangular part of the matrix packed in an array (row major)
+ */
+ private def dspr(alpha: Double, x: DoubleMatrix, L: DoubleMatrix) = {
+ val n = x.length
+ var i = 0
+ var j = 0
+ var idx = 0
+ var axi = 0.0
+ val xd = x.data
+ val Ld = L.data
+ while (i < n) {
+ axi = alpha * xd(i)
+ j = 0
+ while (j <= i) {
+ Ld(idx) += axi * xd(j)
+ j += 1
+ idx += 1
+ }
+ i += 1
+ }
+ }
+
+ /**
* Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs
*/
def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])],
@@ -376,7 +412,8 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
for (productBlock <- 0 until numBlocks) {
for (p <- 0 until blockFactors(productBlock).length) {
val x = new DoubleMatrix(blockFactors(productBlock)(p))
- fillXtX(x, tempXtX)
+ tempXtX.fill(0.0)
+ dspr(1.0, x, tempXtX)
val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p)
for (i <- 0 until us.length) {
implicitPrefs match {
@@ -387,7 +424,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
// Extension to the original paper to handle rs(i) < 0. confidence is a function
// of |rs(i)| instead so that it is never negative:
val confidence = 1 + alpha * abs(rs(i))
- userXtX(us(i)).addi(tempXtX.mul(confidence - 1))
+ SimpleBlas.axpy(confidence - 1.0, tempXtX, userXtX(us(i)))
// For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i)
// means we try to reconstruct 0. We add terms only where P = 1, so, term below
// is now only added for rs(i) > 0:
@@ -400,39 +437,20 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
}
// Solve the least-squares problem for each user and return the new feature vectors
- userXtX.zipWithIndex.map{ case (triangularXtX, index) =>
+ Array.range(0, numUsers).map { index =>
// Compute the full XtX matrix from the lower-triangular part we got above
- fillFullMatrix(triangularXtX, fullXtX)
+ fillFullMatrix(userXtX(index), fullXtX)
// Add regularization
(0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda)
// Solve the resulting matrix, which is symmetric and positive-definite
implicitPrefs match {
case false => Solve.solvePositive(fullXtX, userXy(index)).data
- case true => Solve.solvePositive(fullXtX.add(YtY.value.get), userXy(index)).data
+ case true => Solve.solvePositive(fullXtX.addi(YtY.value.get), userXy(index)).data
}
}
}
/**
- * Set xtxDest to the lower-triangular part of x transpose * x. For efficiency in summing
- * these matrices, we store xtxDest as only rank * (rank+1) / 2 values, namely the values
- * at (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), etc in that order.
- */
- private def fillXtX(x: DoubleMatrix, xtxDest: DoubleMatrix) {
- var i = 0
- var pos = 0
- while (i < x.length) {
- var j = 0
- while (j <= i) {
- xtxDest.data(pos) = x.data(i) * x.data(j)
- pos += 1
- j += 1
- }
- i += 1
- }
- }
-
- /**
* Given a triangular matrix in the order of fillXtX above, compute the full symmetric square
* matrix that it represents, storing it into destMatrix.
*/
@@ -455,9 +473,10 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
/**
- * Top-level methods for calling Alternating Least Squares (ALS) matrix factorizaton.
+ * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization.
*/
object ALS {
+
/**
* Train a matrix factorization model given an RDD of ratings given by users to some products,
* in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the
@@ -470,15 +489,39 @@ object ALS {
* @param iterations number of iterations of ALS (recommended: 10-20)
* @param lambda regularization factor (recommended: 0.01)
* @param blocks level of parallelism to split computation into
+ * @param seed random seed
*/
def train(
ratings: RDD[Rating],
rank: Int,
iterations: Int,
lambda: Double,
- blocks: Int)
- : MatrixFactorizationModel =
- {
+ blocks: Int,
+ seed: Long
+ ): MatrixFactorizationModel = {
+ new ALS(blocks, rank, iterations, lambda, false, 1.0, seed).run(ratings)
+ }
+
+ /**
+ * Train a matrix factorization model given an RDD of ratings given by users to some products,
+ * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the
+ * product of two lower-rank matrices of a given rank (number of features). To solve for these
+ * features, we run a given number of iterations of ALS. This is done using a level of
+ * parallelism given by `blocks`.
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ * @param lambda regularization factor (recommended: 0.01)
+ * @param blocks level of parallelism to split computation into
+ */
+ def train(
+ ratings: RDD[Rating],
+ rank: Int,
+ iterations: Int,
+ lambda: Double,
+ blocks: Int
+ ): MatrixFactorizationModel = {
new ALS(blocks, rank, iterations, lambda, false, 1.0).run(ratings)
}
@@ -495,8 +538,7 @@ object ALS {
* @param lambda regularization factor (recommended: 0.01)
*/
def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double)
- : MatrixFactorizationModel =
- {
+ : MatrixFactorizationModel = {
train(ratings, rank, iterations, lambda, -1)
}
@@ -512,8 +554,7 @@ object ALS {
* @param iterations number of iterations of ALS (recommended: 10-20)
*/
def train(ratings: RDD[Rating], rank: Int, iterations: Int)
- : MatrixFactorizationModel =
- {
+ : MatrixFactorizationModel = {
train(ratings, rank, iterations, 0.01, -1)
}
@@ -530,6 +571,7 @@ object ALS {
* @param lambda regularization factor (recommended: 0.01)
* @param blocks level of parallelism to split computation into
* @param alpha confidence parameter (only applies when immplicitPrefs = true)
+ * @param seed random seed
*/
def trainImplicit(
ratings: RDD[Rating],
@@ -537,9 +579,34 @@ object ALS {
iterations: Int,
lambda: Double,
blocks: Int,
- alpha: Double)
- : MatrixFactorizationModel =
- {
+ alpha: Double,
+ seed: Long
+ ): MatrixFactorizationModel = {
+ new ALS(blocks, rank, iterations, lambda, true, alpha, seed).run(ratings)
+ }
+
+ /**
+ * Train a matrix factorization model given an RDD of 'implicit preferences' given by users
+ * to some products, in the form of (userID, productID, preference) pairs. We approximate the
+ * ratings matrix as the product of two lower-rank matrices of a given rank (number of features).
+ * To solve for these features, we run a given number of iterations of ALS. This is done using
+ * a level of parallelism given by `blocks`.
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ * @param lambda regularization factor (recommended: 0.01)
+ * @param blocks level of parallelism to split computation into
+ * @param alpha confidence parameter (only applies when immplicitPrefs = true)
+ */
+ def trainImplicit(
+ ratings: RDD[Rating],
+ rank: Int,
+ iterations: Int,
+ lambda: Double,
+ blocks: Int,
+ alpha: Double
+ ): MatrixFactorizationModel = {
new ALS(blocks, rank, iterations, lambda, true, alpha).run(ratings)
}
@@ -555,8 +622,8 @@ object ALS {
* @param iterations number of iterations of ALS (recommended: 10-20)
* @param lambda regularization factor (recommended: 0.01)
*/
- def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double,
- alpha: Double): MatrixFactorizationModel = {
+ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double)
+ : MatrixFactorizationModel = {
trainImplicit(ratings, rank, iterations, lambda, -1, alpha)
}
@@ -573,8 +640,7 @@ object ALS {
* @param iterations number of iterations of ALS (recommended: 10-20)
*/
def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int)
- : MatrixFactorizationModel =
- {
+ : MatrixFactorizationModel = {
trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 45e7d2db00..5aab9aba8f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -23,9 +23,10 @@ import scala.util.Random
import org.scalatest.FunSuite
-import org.jblas._
+import org.jblas.DoubleMatrix
import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.SparkContext._
object ALSSuite {
@@ -115,6 +116,18 @@ class ALSSuite extends FunSuite with LocalSparkContext {
testALS(100, 200, 2, 15, 0.7, 0.4, true, false, true)
}
+ test("pseudorandomness") {
+ val ratings = sc.parallelize(ALSSuite.generateRatings(10, 20, 5, 0.5, false, false)._1, 2)
+ val model11 = ALS.train(ratings, 5, 1, 1.0, 2, 1)
+ val model12 = ALS.train(ratings, 5, 1, 1.0, 2, 1)
+ val u11 = model11.userFeatures.values.flatMap(_.toList).collect().toList
+ val u12 = model12.userFeatures.values.flatMap(_.toList).collect().toList
+ val model2 = ALS.train(ratings, 5, 1, 1.0, 2, 2)
+ val u2 = model2.userFeatures.values.flatMap(_.toList).collect().toList
+ assert(u11 == u12)
+ assert(u11 != u2)
+ }
+
/**
* Test if we can correctly factorize R = U * P where U and P are of known rank.
*