aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org
diff options
context:
space:
mode:
author Sean Owen <srowen@gmail.com> 2014-08-01 07:32:53 -0700
committer Xiangrui Meng <meng@databricks.com> 2014-08-01 07:32:53 -0700
commit 82d209d43fb543c174e640667de15b00c7fb5d35 (patch)
tree 98c3675e5be55718c34aca6bedaff0dc1819e66c /mllib/src/test/java/org
parent a32f0fb73a739c56208cafcd9f08618fb6dd8859 (diff)
download spark-82d209d43fb543c174e640667de15b00c7fb5d35.tar.gz
spark-82d209d43fb543c174e640667de15b00c7fb5d35.tar.bz2
spark-82d209d43fb543c174e640667de15b00c7fb5d35.zip
SPARK-2768 [MLLIB] Add product, user recommend method to MatrixFactorizationModel
Right now, `MatrixFactorizationModel` can only predict a score for one or more `(user,product)` tuples. As a comment in the file notes, it would be more useful to expose a recommend method that computes the top N scoring products for a user (or vice versa – users for a product). (This also corrects some long lines in the Java ALS test suite.) As you can see, it's a little messy to access the class from Java. Should there be a Java-friendly wrapper for it? With a pointer to where that should go, I could add that. Author: Sean Owen <srowen@gmail.com> Closes #1687 from srowen/SPARK-2768 and squashes the following commits: b349675 [Sean Owen] Additional review changes c9edb04 [Sean Owen] Updates from code review 7bc35f9 [Sean Owen] Add recommend methods to MatrixFactorizationModel
Diffstat (limited to 'mllib/src/test/java/org')
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java 75
1 files changed, 58 insertions, 17 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index bf2365f820..f6ca964322 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -20,6 +20,11 @@ package org.apache.spark.mllib.recommendation;
import java.io.Serializable;
import java.util.List;
+import scala.Tuple2;
+import scala.Tuple3;
+
+import org.jblas.DoubleMatrix;
+
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@@ -28,8 +33,6 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.jblas.DoubleMatrix;
-
public class JavaALSSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -44,21 +47,28 @@ public class JavaALSSuite implements Serializable {
sc = null;
}
- static void validatePrediction(MatrixFactorizationModel model, int users, int products, int features,
- DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) {
+ static void validatePrediction(
+ MatrixFactorizationModel model,
+ int users,
+ int products,
+ int features,
+ DoubleMatrix trueRatings,
+ double matchThreshold,
+ boolean implicitPrefs,
+ DoubleMatrix truePrefs) {
DoubleMatrix predictedU = new DoubleMatrix(users, features);
- List<scala.Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
+ List<Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
for (int i = 0; i < features; ++i) {
- for (scala.Tuple2<Object, double[]> userFeature : userFeatures) {
+ for (Tuple2<Object, double[]> userFeature : userFeatures) {
predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]);
}
}
DoubleMatrix predictedP = new DoubleMatrix(products, features);
- List<scala.Tuple2<Object, double[]>> productFeatures =
+ List<Tuple2<Object, double[]>> productFeatures =
model.productFeatures().toJavaRDD().collect();
for (int i = 0; i < features; ++i) {
- for (scala.Tuple2<Object, double[]> productFeature : productFeatures) {
+ for (Tuple2<Object, double[]> productFeature : productFeatures) {
predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]);
}
}
@@ -75,7 +85,8 @@ public class JavaALSSuite implements Serializable {
}
}
} else {
- // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's implicit ALS tests)
+ // For implicit prefs we use the confidence-weighted RMSE to test
+ // (ref Mahout's implicit ALS tests)
double sqErr = 0.0;
double denom = 0.0;
for (int u = 0; u < users; ++u) {
@@ -100,7 +111,7 @@ public class JavaALSSuite implements Serializable {
int iterations = 15;
int users = 50;
int products = 100;
- scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
users, products, features, 0.7, false, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
@@ -114,14 +125,14 @@ public class JavaALSSuite implements Serializable {
int iterations = 15;
int users = 100;
int products = 200;
- scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
users, products, features, 0.7, false, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
- .setIterations(iterations)
- .run(data.rdd());
+ .setIterations(iterations)
+ .run(data.rdd());
validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
}
@@ -131,7 +142,7 @@ public class JavaALSSuite implements Serializable {
int iterations = 15;
int users = 80;
int products = 160;
- scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
users, products, features, 0.7, true, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
@@ -145,7 +156,7 @@ public class JavaALSSuite implements Serializable {
int iterations = 15;
int users = 100;
int products = 200;
- scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
users, products, features, 0.7, true, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
@@ -163,12 +174,42 @@ public class JavaALSSuite implements Serializable {
int iterations = 15;
int users = 80;
int products = 160;
- scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
users, products, features, 0.7, true, true);
JavaRDD<Rating> data = sc.parallelize(testData._1());
- MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
+ MatrixFactorizationModel model = new ALS().setRank(features)
+ .setIterations(iterations)
+ .setImplicitPrefs(true)
+ .setSeed(8675309L)
+ .run(data.rdd());
validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
}
+ @Test
+ public void runRecommend() {
+ int features = 5;
+ int iterations = 10;
+ int users = 200;
+ int products = 50;
+ Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, true, false);
+ JavaRDD<Rating> data = sc.parallelize(testData._1());
+ MatrixFactorizationModel model = new ALS().setRank(features)
+ .setIterations(iterations)
+ .setImplicitPrefs(true)
+ .setSeed(8675309L)
+ .run(data.rdd());
+ validateRecommendations(model.recommendProducts(1, 10), 10);
+ validateRecommendations(model.recommendUsers(1, 20), 20);
+ }
+
+ private static void validateRecommendations(Rating[] recommendations, int howMany) {
+ Assert.assertEquals(howMany, recommendations.length);
+ for (int i = 1; i < recommendations.length; i++) {
+ Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating());
+ }
+ Assert.assertTrue(recommendations[0].rating() > 0.7);
+ }
+
}