diff options
author | Nick Pentreath <nick.pentreath@gmail.com> | 2013-09-06 14:45:05 +0200 |
---|---|---|
committer | Nick Pentreath <nick.pentreath@gmail.com> | 2013-09-06 14:45:05 +0200 |
commit | 737f01a1ef49e4a12f24799c4324b3a60712758e (patch) | |
tree | 03d80efb497b71fff4a25829498eb09f76fad9c6 /mllib/src/test | |
parent | a106ed8b97e707b36818c11d1d7211fa28636178 (diff) | |
download | spark-737f01a1ef49e4a12f24799c4324b3a60712758e.tar.gz spark-737f01a1ef49e4a12f24799c4324b3a60712758e.tar.bz2 spark-737f01a1ef49e4a12f24799c4324b3a60712758e.zip |
Adding algorithm for implicit feedback data to ALS
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java | 77 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala | 71 |
2 files changed, 120 insertions, 28 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index 3323f6cee2..ec545efcfa 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -19,6 +19,7 @@ package org.apache.spark.mllib.recommendation; import java.io.Serializable; import java.util.List; +import java.lang.Math; import scala.Tuple2; @@ -48,7 +49,7 @@ public class JavaALSSuite implements Serializable { } void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, - DoubleMatrix trueRatings, double matchThreshold) { + DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { DoubleMatrix predictedU = new DoubleMatrix(users, features); List<scala.Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect(); for (int i = 0; i < features; ++i) { @@ -68,12 +69,32 @@ public class JavaALSSuite implements Serializable { DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose()); - for (int u = 0; u < users; ++u) { - for (int p = 0; p < products; ++p) { - double prediction = predictedRatings.get(u, p); - double correct = trueRatings.get(u, p); - Assert.assertTrue(Math.abs(prediction - correct) < matchThreshold); + if (!implicitPrefs) { + for (int u = 0; u < users; ++u) { + for (int p = 0; p < products; ++p) { + double prediction = predictedRatings.get(u, p); + double correct = trueRatings.get(u, p); + Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", + prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold); + } } + } else { + // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's implicit ALS tests) + double sqErr = 0.0; + double denom = 0.0; + for (int u = 0; u < users; ++u) { + for (int p = 0; p < products; ++p) { + double prediction = predictedRatings.get(u, p); + double truePref = truePrefs.get(u, p); + double confidence = 1.0 + /* alpha = */ 1.0 * trueRatings.get(u, p); + double err = confidence * (truePref - prediction) * (truePref - prediction); + sqErr += err; + denom += 1.0; + } + } + double rmse = Math.sqrt(sqErr / denom); + Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f", + rmse, matchThreshold), Math.abs(rmse) < matchThreshold); } } @@ -83,12 +104,12 @@ public class JavaALSSuite implements Serializable { int iterations = 15; int users = 10; int products = 10; - scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7); + scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, false); JavaRDD<Rating> data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); - validatePrediction(model, users, products, features, testData._2(), 0.3); + validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); } @Test @@ -97,14 +118,46 @@ public class JavaALSSuite implements Serializable { int iterations = 15; int users = 20; int products = 30; - scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7); + scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, false); JavaRDD<Rating> data = sc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.3); + validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); + } + + @Test + public void runImplicitALSUsingStaticMethods() { + int features = 1; + int iterations = 15; + int users = 40; + int products = 80; + scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, true); + + JavaRDD<Rating> data = sc.parallelize(testData._1()); + MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); + validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + } + + @Test + public void runImplicitALSUsingConstructor() { + int features = 2; + int iterations = 15; + int users = 50; + int products = 100; + scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, true); + + JavaRDD<Rating> data = sc.parallelize(testData._1()); + + MatrixFactorizationModel model = new ALS().setRank(features) + .setIterations(iterations) + .setImplicitPrefs(true) + .run(data.rdd()); + validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 347ef238f4..1ab181d35a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -34,16 +34,19 @@ object ALSSuite { users: Int, products: Int, features: Int, - samplingRate: Double): (java.util.List[Rating], DoubleMatrix) = { - val (sampledRatings, trueRatings) = generateRatings(users, products, features, samplingRate) - (seqAsJavaList(sampledRatings), trueRatings) + samplingRate: Double, + implicitPrefs: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = { + val (sampledRatings, trueRatings, truePrefs) = + generateRatings(users, products, features, samplingRate, implicitPrefs) + (seqAsJavaList(sampledRatings), trueRatings, truePrefs) } def generateRatings( users: Int, products: Int, features: Int, - samplingRate: Double): (Seq[Rating], DoubleMatrix) = { + samplingRate: Double, + implicitPrefs: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = { val rand = new Random(42) // Create a random matrix with uniform values from -1 to 1 @@ -52,14 +55,20 @@ object ALSSuite { val userMatrix = randomMatrix(users, features) val productMatrix = randomMatrix(features, products) - val trueRatings = userMatrix.mmul(productMatrix) + val (trueRatings, truePrefs) = implicitPrefs match { + case true => + val raw = new DoubleMatrix(users, products, Array.fill(users * products)(rand.nextInt(10).toDouble): _*) + val prefs = new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*) + (raw, prefs) + case false => (userMatrix.mmul(productMatrix), null) + } val sampledRatings = { for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate) yield Rating(u, p, trueRatings.get(u, p)) } - (sampledRatings, trueRatings) + (sampledRatings, trueRatings, truePrefs) } } @@ -85,6 +94,14 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll { testALS(20, 30, 2, 15, 0.7, 0.3) } + test("rank-1 matrices implicit") { + testALS(40, 80, 1, 15, 0.7, 0.4, true) + } + + test("rank-2 matrices implicit") { + testALS(50, 100, 2, 15, 0.7, 0.4, true) + } + /** * Test if we can correctly factorize R = U * P where U and P are of known rank. * @@ -96,11 +113,14 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll { * @param matchThreshold max difference allowed to consider a predicted rating correct */ def testALS(users: Int, products: Int, features: Int, iterations: Int, - samplingRate: Double, matchThreshold: Double) + samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false) { - val (sampledRatings, trueRatings) = ALSSuite.generateRatings(users, products, - features, samplingRate) - val model = ALS.train(sc.parallelize(sampledRatings), features, iterations) + val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products, + features, samplingRate, implicitPrefs) + val model = implicitPrefs match { + case false => ALS.train(sc.parallelize(sampledRatings), features, iterations) + case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations) + } val predictedU = new DoubleMatrix(users, features) for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) { @@ -112,12 +132,31 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll { } val predictedRatings = predictedU.mmul(predictedP.transpose) - for (u <- 0 until users; p <- 0 until products) { - val prediction = predictedRatings.get(u, p) - val correct = trueRatings.get(u, p) - if (math.abs(prediction - correct) > matchThreshold) { - fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format( - u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP)) + if (!implicitPrefs) { + for (u <- 0 until users; p <- 0 until products) { + val prediction = predictedRatings.get(u, p) + val correct = trueRatings.get(u, p) + if (math.abs(prediction - correct) > matchThreshold) { + fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format( + u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP)) + } + } + } else { + // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's tests) + var sqErr = 0.0 + var denom = 0.0 + for (u <- 0 until users; p <- 0 until products) { + val prediction = predictedRatings.get(u, p) + val truePref = truePrefs.get(u, p) + val confidence = 1 + 1.0 * trueRatings.get(u, p) + val err = confidence * (truePref - prediction) * (truePref - prediction) + sqErr += err + denom += 1 + } + val rmse = math.sqrt(sqErr / denom) + if (math.abs(rmse) > matchThreshold) { + fail("Model failed to predict RMSE: %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format( + rmse, truePrefs, predictedRatings, predictedU, predictedP)) } } } |