diff options
author | Sean Owen <sowen@cloudera.com> | 2014-02-19 23:44:53 -0800 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-02-19 23:44:53 -0800 |
commit | 9e63f80e75bb6d9bbe6df268908c3219de6852d9 (patch) | |
tree | 40018b5ef81b996c3480783543178fac88c346a5 /mllib/src/test/java/org | |
parent | f9b7d64a4e7dd03be672728335cb72df4be5dbf6 (diff) | |
download | spark-9e63f80e75bb6d9bbe6df268908c3219de6852d9.tar.gz spark-9e63f80e75bb6d9bbe6df268908c3219de6852d9.tar.bz2 spark-9e63f80e75bb6d9bbe6df268908c3219de6852d9.zip |
MLLIB-22. Support negative implicit input in ALS
I'm back with another less trivial suggestion for ALS:
In ALS for implicit feedback, input values are treated as weights on squared-errors in a loss function (or rather, the weight is a simple function of the input r, like c = 1 + alpha*r). The paper on which it's based assumes that the input is positive. Indeed, if the input is negative, it will create a negative weight on squared-errors, which causes things to go haywire. The optimization will try to make the error in a cell as large possible, and the result is silently bogus.
There is a good use case for negative input values though. Implicit feedback is usually collected from signals of positive interaction like a view or like or buy, but equally, can come from "not interested" signals. The natural representation is negative values.
The algorithm can be extended quite simply to provide a sound interpretation of these values: negative values should encourage the factorization to come up with 0 for cells with large negative input values, just as much as positive values encourage it to come up with 1.
The implications for the algorithm are simple:
* the confidence function value must not be negative, and so can become 1 + alpha*|r|
* the matrix P should have a value 1 where the input R is _positive_, not merely where it is non-zero. Actually, that's what the paper already says, it's just that we can't assume P = 1 when a cell in R is specified anymore, since it may be negative
This in turn entails just a few lines of code change in `ALS.scala`:
* `rs(i)` becomes `abs(rs(i))`
* When constructing `userXy(us(i))`, it's implicitly only adding where P is 1. That had been true for any us(i) that is iterated over, before, since these are exactly the ones for which P is 1. But now P is zero where rs(i) <= 0, and should not be added
I think it's a safe change because:
* It doesn't change any existing behavior (unless you're using negative values, in which case results are already borked)
* It's the simplest direct extension of the paper's algorithm
* (I've used it to good effect in production FWIW)
Tests included.
I tweaked minor things en route:
* `ALS.scala` javadoc writes "R = Xt*Y" when the paper and rest of code defines it as "R = X*Yt"
* RMSE in the ALS tests uses a confidence-weighted mean, but the denominator is not actually sum of weights
Excuse my Scala style; I'm sure it needs tweaks.
Author: Sean Owen <sowen@cloudera.com>
Closes #500 from srowen/ALSNegativeImplicitInput and squashes the following commits:
cf902a9 [Sean Owen] Support negative implicit input in ALS
953be1c [Sean Owen] Make weighted RMSE in ALS test actually weighted; adjust comment about R = X*Yt
Diffstat (limited to 'mllib/src/test/java/org')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java | 32 |
1 files changed, 23 insertions, 9 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index b40f552e0d..b150334deb 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -19,7 +19,6 @@ package org.apache.spark.mllib.recommendation; import java.io.Serializable; import java.util.List; -import java.lang.Math; import org.junit.After; import org.junit.Assert; @@ -46,7 +45,7 @@ public class JavaALSSuite implements Serializable { System.clearProperty("spark.driver.port"); } - void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, + static void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { DoubleMatrix predictedU = new DoubleMatrix(users, features); List<scala.Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect(); @@ -84,15 +83,15 @@ public class JavaALSSuite implements Serializable { for (int p = 0; p < products; ++p) { double prediction = predictedRatings.get(u, p); double truePref = truePrefs.get(u, p); - double confidence = 1.0 + /* alpha = */ 1.0 * trueRatings.get(u, p); + double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p)); double err = confidence * (truePref - prediction) * (truePref - prediction); sqErr += err; - denom += 1.0; + denom += confidence; } } double rmse = Math.sqrt(sqErr / denom); Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f", - rmse, matchThreshold), Math.abs(rmse) < matchThreshold); + rmse, matchThreshold), rmse < matchThreshold); } } @@ -103,7 +102,7 @@ public class JavaALSSuite implements Serializable { int users = 50; int products = 100; scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, false); + users, products, features, 0.7, false, false); JavaRDD<Rating> data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); @@ -117,7 +116,7 @@ public class JavaALSSuite implements Serializable { int users = 100; int products = 200; scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, false); + users, products, features, 0.7, false, false); JavaRDD<Rating> data = sc.parallelize(testData._1()); @@ -134,7 +133,7 @@ public class JavaALSSuite implements Serializable { int users = 80; int products = 160; scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, true); + users, products, features, 0.7, true, false); JavaRDD<Rating> data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); @@ -148,7 +147,7 @@ public class JavaALSSuite implements Serializable { int users = 100; int products = 200; scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, true); + users, products, features, 0.7, true, false); JavaRDD<Rating> data = sc.parallelize(testData._1()); @@ -158,4 +157,19 @@ public class JavaALSSuite implements Serializable { .run(data.rdd()); validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); } + + @Test + public void runImplicitALSWithNegativeWeight() { + int features = 2; + int iterations = 15; + int users = 80; + int products = 160; + scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, true, true); + + JavaRDD<Rating> data = sc.parallelize(testData._1()); + MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); + validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + } + } |