aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorNick Pentreath <nick.pentreath@gmail.com>2013-09-06 14:45:05 +0200
committerNick Pentreath <nick.pentreath@gmail.com>2013-09-06 14:45:05 +0200
commit737f01a1ef49e4a12f24799c4324b3a60712758e (patch)
tree03d80efb497b71fff4a25829498eb09f76fad9c6 /mllib/src/test
parenta106ed8b97e707b36818c11d1d7211fa28636178 (diff)
downloadspark-737f01a1ef49e4a12f24799c4324b3a60712758e.tar.gz
spark-737f01a1ef49e4a12f24799c4324b3a60712758e.tar.bz2
spark-737f01a1ef49e4a12f24799c4324b3a60712758e.zip
Adding algorithm for implicit feedback data to ALS
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java77
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala71
2 files changed, 120 insertions, 28 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index 3323f6cee2..ec545efcfa 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.recommendation;
import java.io.Serializable;
import java.util.List;
+import java.lang.Math;
import scala.Tuple2;
@@ -48,7 +49,7 @@ public class JavaALSSuite implements Serializable {
}
void validatePrediction(MatrixFactorizationModel model, int users, int products, int features,
- DoubleMatrix trueRatings, double matchThreshold) {
+ DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) {
DoubleMatrix predictedU = new DoubleMatrix(users, features);
List<scala.Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
for (int i = 0; i < features; ++i) {
@@ -68,12 +69,32 @@ public class JavaALSSuite implements Serializable {
DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose());
- for (int u = 0; u < users; ++u) {
- for (int p = 0; p < products; ++p) {
- double prediction = predictedRatings.get(u, p);
- double correct = trueRatings.get(u, p);
- Assert.assertTrue(Math.abs(prediction - correct) < matchThreshold);
+ if (!implicitPrefs) {
+ for (int u = 0; u < users; ++u) {
+ for (int p = 0; p < products; ++p) {
+ double prediction = predictedRatings.get(u, p);
+ double correct = trueRatings.get(u, p);
+ Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
+ prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
+ }
}
+ } else {
+ // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's implicit ALS tests)
+ double sqErr = 0.0;
+ double denom = 0.0;
+ for (int u = 0; u < users; ++u) {
+ for (int p = 0; p < products; ++p) {
+ double prediction = predictedRatings.get(u, p);
+ double truePref = truePrefs.get(u, p);
+ double confidence = 1.0 + /* alpha = */ 1.0 * trueRatings.get(u, p);
+ double err = confidence * (truePref - prediction) * (truePref - prediction);
+ sqErr += err;
+ denom += 1.0;
+ }
+ }
+ double rmse = Math.sqrt(sqErr / denom);
+ Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
+ rmse, matchThreshold), Math.abs(rmse) < matchThreshold);
}
}
@@ -83,12 +104,12 @@ public class JavaALSSuite implements Serializable {
int iterations = 15;
int users = 10;
int products = 10;
- scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
- users, products, features, 0.7);
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
- validatePrediction(model, users, products, features, testData._2(), 0.3);
+ validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
}
@Test
@@ -97,14 +118,46 @@ public class JavaALSSuite implements Serializable {
int iterations = 15;
int users = 20;
int products = 30;
- scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
- users, products, features, 0.7);
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
.run(data.rdd());
- validatePrediction(model, users, products, features, testData._2(), 0.3);
+ validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
+ }
+
+ @Test
+ public void runImplicitALSUsingStaticMethods() {
+ int features = 1;
+ int iterations = 15;
+ int users = 40;
+ int products = 80;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, true);
+
+ JavaRDD<Rating> data = sc.parallelize(testData._1());
+ MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
+ validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+ }
+
+ @Test
+ public void runImplicitALSUsingConstructor() {
+ int features = 2;
+ int iterations = 15;
+ int users = 50;
+ int products = 100;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, true);
+
+ JavaRDD<Rating> data = sc.parallelize(testData._1());
+
+ MatrixFactorizationModel model = new ALS().setRank(features)
+ .setIterations(iterations)
+ .setImplicitPrefs(true)
+ .run(data.rdd());
+ validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 347ef238f4..1ab181d35a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -34,16 +34,19 @@ object ALSSuite {
users: Int,
products: Int,
features: Int,
- samplingRate: Double): (java.util.List[Rating], DoubleMatrix) = {
- val (sampledRatings, trueRatings) = generateRatings(users, products, features, samplingRate)
- (seqAsJavaList(sampledRatings), trueRatings)
+ samplingRate: Double,
+ implicitPrefs: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = {
+ val (sampledRatings, trueRatings, truePrefs) =
+ generateRatings(users, products, features, samplingRate, implicitPrefs)
+ (seqAsJavaList(sampledRatings), trueRatings, truePrefs)
}
def generateRatings(
users: Int,
products: Int,
features: Int,
- samplingRate: Double): (Seq[Rating], DoubleMatrix) = {
+ samplingRate: Double,
+ implicitPrefs: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = {
val rand = new Random(42)
// Create a random matrix with uniform values from -1 to 1
@@ -52,14 +55,20 @@ object ALSSuite {
val userMatrix = randomMatrix(users, features)
val productMatrix = randomMatrix(features, products)
- val trueRatings = userMatrix.mmul(productMatrix)
+ val (trueRatings, truePrefs) = implicitPrefs match {
+ case true =>
+ val raw = new DoubleMatrix(users, products, Array.fill(users * products)(rand.nextInt(10).toDouble): _*)
+ val prefs = new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*)
+ (raw, prefs)
+ case false => (userMatrix.mmul(productMatrix), null)
+ }
val sampledRatings = {
for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate)
yield Rating(u, p, trueRatings.get(u, p))
}
- (sampledRatings, trueRatings)
+ (sampledRatings, trueRatings, truePrefs)
}
}
@@ -85,6 +94,14 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
testALS(20, 30, 2, 15, 0.7, 0.3)
}
+ test("rank-1 matrices implicit") {
+ testALS(40, 80, 1, 15, 0.7, 0.4, true)
+ }
+
+ test("rank-2 matrices implicit") {
+ testALS(50, 100, 2, 15, 0.7, 0.4, true)
+ }
+
/**
* Test if we can correctly factorize R = U * P where U and P are of known rank.
*
@@ -96,11 +113,14 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
* @param matchThreshold max difference allowed to consider a predicted rating correct
*/
def testALS(users: Int, products: Int, features: Int, iterations: Int,
- samplingRate: Double, matchThreshold: Double)
+ samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false)
{
- val (sampledRatings, trueRatings) = ALSSuite.generateRatings(users, products,
- features, samplingRate)
- val model = ALS.train(sc.parallelize(sampledRatings), features, iterations)
+ val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
+ features, samplingRate, implicitPrefs)
+ val model = implicitPrefs match {
+ case false => ALS.train(sc.parallelize(sampledRatings), features, iterations)
+ case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations)
+ }
val predictedU = new DoubleMatrix(users, features)
for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) {
@@ -112,12 +132,31 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
}
val predictedRatings = predictedU.mmul(predictedP.transpose)
- for (u <- 0 until users; p <- 0 until products) {
- val prediction = predictedRatings.get(u, p)
- val correct = trueRatings.get(u, p)
- if (math.abs(prediction - correct) > matchThreshold) {
- fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
- u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP))
+ if (!implicitPrefs) {
+ for (u <- 0 until users; p <- 0 until products) {
+ val prediction = predictedRatings.get(u, p)
+ val correct = trueRatings.get(u, p)
+ if (math.abs(prediction - correct) > matchThreshold) {
+ fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
+ u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP))
+ }
+ }
+ } else {
+ // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's tests)
+ var sqErr = 0.0
+ var denom = 0.0
+ for (u <- 0 until users; p <- 0 until products) {
+ val prediction = predictedRatings.get(u, p)
+ val truePref = truePrefs.get(u, p)
+ val confidence = 1 + 1.0 * trueRatings.get(u, p)
+ val err = confidence * (truePref - prediction) * (truePref - prediction)
+ sqErr += err
+ denom += 1
+ }
+ val rmse = math.sqrt(sqErr / denom)
+ if (math.abs(rmse) > matchThreshold) {
+ fail("Model failed to predict RMSE: %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
+ rmse, truePrefs, predictedRatings, predictedU, predictedP))
}
}
}