aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala 8
1 file changed, 7 insertions, 1 deletion
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index f1ae7b85b4..cc56fd6ef2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -507,6 +507,9 @@ class ALS private (
val tempXtX = DoubleMatrix.zeros(triangleSize)
val fullXtX = DoubleMatrix.zeros(rank, rank)
+ // Count each user's ratings so that lambda can be scaled by that count (per-user regularization)
+ val numRatings = Array.fill(numUsers)(0)
+
// Compute the XtX and Xy values for each user by adding products it rated in each product
// block
for (productBlock <- 0 until numProductBlocks) {
@@ -519,6 +522,7 @@ class ALS private (
if (implicitPrefs) {
var i = 0
while (i < us.length) {
+ numRatings(us(i)) += 1
// Extension to the original paper to handle rs(i) < 0. confidence is a function
// of |rs(i)| instead so that it is never negative:
val confidence = 1 + alpha * abs(rs(i))
@@ -534,6 +538,7 @@ class ALS private (
} else {
var i = 0
while (i < us.length) {
+ numRatings(us(i)) += 1
userXtX(us(i)).addi(tempXtX)
SimpleBlas.axpy(rs(i), x, userXy(us(i)))
i += 1
@@ -550,9 +555,10 @@ class ALS private (
// Compute the full XtX matrix from the lower-triangular part we got above
fillFullMatrix(userXtX(index), fullXtX)
// Add regularization
+ val regParam = numRatings(index) * lambda
var i = 0
while (i < rank) {
- fullXtX.data(i * rank + i) += lambda
+ fullXtX.data(i * rank + i) += regParam
i += 1
}
// Solve the resulting matrix, which is symmetric and positive-definite