aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
author Xiangrui Meng <meng@databricks.com> 2014-11-13 11:42:27 -0800
committer Xiangrui Meng <meng@databricks.com> 2014-11-13 11:42:27 -0800
commit ca26a212fda39a15fde09dfdb2fbe69580a717f6 (patch)
tree e6473ef3a55f9c353642fafb4c779e8a2546d50c /mllib/src/test
parent ce0333f9a008348692bb9a200449d2d992e7825e (diff)
download spark-ca26a212fda39a15fde09dfdb2fbe69580a717f6.tar.gz
spark-ca26a212fda39a15fde09dfdb2fbe69580a717f6.tar.bz2
spark-ca26a212fda39a15fde09dfdb2fbe69580a717f6.zip
[SPARK-4378][MLLIB] make ALS more Java-friendly
Add Java-friendly version of `run` and `predict`, and use bulk prediction in Java unit tests. The user guide update will come later (though we may not save many lines of code there).

srowen

Author: Xiangrui Meng <meng@databricks.com>

Closes #3240 from mengxr/SPARK-4378 and squashes the following commits:

6581503 [Xiangrui Meng] check number of predictions
6c8bbd1 [Xiangrui Meng] make ALS more Java-friendly
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java 74
1 files changed, 31 insertions, 43 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index f6ca964322..af688c504c 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -23,13 +23,14 @@ import java.util.List;
import scala.Tuple2;
import scala.Tuple3;
+import com.google.common.collect.Lists;
import org.jblas.DoubleMatrix;
-
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
@@ -47,61 +48,48 @@ public class JavaALSSuite implements Serializable {
sc = null;
}
- static void validatePrediction(
+ void validatePrediction(
MatrixFactorizationModel model,
int users,
int products,
- int features,
DoubleMatrix trueRatings,
double matchThreshold,
boolean implicitPrefs,
DoubleMatrix truePrefs) {
- DoubleMatrix predictedU = new DoubleMatrix(users, features);
- List<Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
- for (int i = 0; i < features; ++i) {
- for (Tuple2<Object, double[]> userFeature : userFeatures) {
- predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]);
- }
- }
- DoubleMatrix predictedP = new DoubleMatrix(products, features);
-
- List<Tuple2<Object, double[]>> productFeatures =
- model.productFeatures().toJavaRDD().collect();
- for (int i = 0; i < features; ++i) {
- for (Tuple2<Object, double[]> productFeature : productFeatures) {
- predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]);
+ List<Tuple2<Integer, Integer>> localUsersProducts =
+ Lists.newArrayListWithCapacity(users * products);
+ for (int u=0; u < users; ++u) {
+ for (int p=0; p < products; ++p) {
+ localUsersProducts.add(new Tuple2<Integer, Integer>(u, p));
}
}
-
- DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose());
-
+ JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts);
+ List<Rating> predictedRatings = model.predict(usersProducts).collect();
+ Assert.assertEquals(users * products, predictedRatings.size());
if (!implicitPrefs) {
- for (int u = 0; u < users; ++u) {
- for (int p = 0; p < products; ++p) {
- double prediction = predictedRatings.get(u, p);
- double correct = trueRatings.get(u, p);
- Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
- prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
- }
+ for (Rating r: predictedRatings) {
+ double prediction = r.rating();
+ double correct = trueRatings.get(r.user(), r.product());
+ Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
+ prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
}
} else {
// For implicit prefs we use the confidence-weighted RMSE to test
// (ref Mahout's implicit ALS tests)
double sqErr = 0.0;
double denom = 0.0;
- for (int u = 0; u < users; ++u) {
- for (int p = 0; p < products; ++p) {
- double prediction = predictedRatings.get(u, p);
- double truePref = truePrefs.get(u, p);
- double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p));
- double err = confidence * (truePref - prediction) * (truePref - prediction);
- sqErr += err;
- denom += confidence;
- }
+ for (Rating r: predictedRatings) {
+ double prediction = r.rating();
+ double truePref = truePrefs.get(r.user(), r.product());
+ double confidence = 1.0 +
+ /* alpha = */ 1.0 * Math.abs(trueRatings.get(r.user(), r.product()));
+ double err = confidence * (truePref - prediction) * (truePref - prediction);
+ sqErr += err;
+ denom += confidence;
}
double rmse = Math.sqrt(sqErr / denom);
Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
- rmse, matchThreshold), rmse < matchThreshold);
+ rmse, matchThreshold), rmse < matchThreshold);
}
}
@@ -116,7 +104,7 @@ public class JavaALSSuite implements Serializable {
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
- validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
+ validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
}
@Test
@@ -132,8 +120,8 @@ public class JavaALSSuite implements Serializable {
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
- .run(data.rdd());
- validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
+ .run(data);
+ validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
}
@Test
@@ -147,7 +135,7 @@ public class JavaALSSuite implements Serializable {
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
- validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+ validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
}
@Test
@@ -165,7 +153,7 @@ public class JavaALSSuite implements Serializable {
.setIterations(iterations)
.setImplicitPrefs(true)
.run(data.rdd());
- validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+ validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
}
@Test
@@ -183,7 +171,7 @@ public class JavaALSSuite implements Serializable {
.setImplicitPrefs(true)
.setSeed(8675309L)
.run(data.rdd());
- validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+ validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
}
@Test