diff options
author | Xiangrui Meng <meng@databricks.com> | 2014-11-13 11:42:27 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-13 11:42:27 -0800 |
commit | ca26a212fda39a15fde09dfdb2fbe69580a717f6 (patch) | |
tree | e6473ef3a55f9c353642fafb4c779e8a2546d50c /mllib/src/test | |
parent | ce0333f9a008348692bb9a200449d2d992e7825e (diff) | |
download | spark-ca26a212fda39a15fde09dfdb2fbe69580a717f6.tar.gz spark-ca26a212fda39a15fde09dfdb2fbe69580a717f6.tar.bz2 spark-ca26a212fda39a15fde09dfdb2fbe69580a717f6.zip |
[SPARK-4378][MLLIB] make ALS more Java-friendly
Add Java-friendly version of `run` and `predict`, and use bulk prediction in Java unit tests. The user guide update will come later (though we may not save many lines of code there). srowen
Author: Xiangrui Meng <meng@databricks.com>
Closes #3240 from mengxr/SPARK-4378 and squashes the following commits:
6581503 [Xiangrui Meng] check number of predictions
6c8bbd1 [Xiangrui Meng] make ALS more Java-friendly
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java | 74 |
1 files changed, 31 insertions, 43 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index f6ca964322..af688c504c 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -23,13 +23,14 @@ import java.util.List; import scala.Tuple2; import scala.Tuple3; +import com.google.common.collect.Lists; import org.jblas.DoubleMatrix; - import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -47,61 +48,48 @@ public class JavaALSSuite implements Serializable { sc = null; } - static void validatePrediction( + void validatePrediction( MatrixFactorizationModel model, int users, int products, - int features, DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { - DoubleMatrix predictedU = new DoubleMatrix(users, features); - List<Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect(); - for (int i = 0; i < features; ++i) { - for (Tuple2<Object, double[]> userFeature : userFeatures) { - predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]); - } - } - DoubleMatrix predictedP = new DoubleMatrix(products, features); - - List<Tuple2<Object, double[]>> productFeatures = - model.productFeatures().toJavaRDD().collect(); - for (int i = 0; i < features; ++i) { - for (Tuple2<Object, double[]> productFeature : productFeatures) { - predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]); + List<Tuple2<Integer, Integer>> localUsersProducts = + Lists.newArrayListWithCapacity(users * products); + for (int u=0; u < users; ++u) { + for (int p=0; p < products; ++p) { + localUsersProducts.add(new Tuple2<Integer, Integer>(u, p)); } } - - DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose()); - + JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts); + List<Rating> predictedRatings = model.predict(usersProducts).collect(); + Assert.assertEquals(users * products, predictedRatings.size()); if (!implicitPrefs) { - for (int u = 0; u < users; ++u) { - for (int p = 0; p < products; ++p) { - double prediction = predictedRatings.get(u, p); - double correct = trueRatings.get(u, p); - Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", - prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold); - } + for (Rating r: predictedRatings) { + double prediction = r.rating(); + double correct = trueRatings.get(r.user(), r.product()); + Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", + prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold); } } else { // For implicit prefs we use the confidence-weighted RMSE to test // (ref Mahout's implicit ALS tests) double sqErr = 0.0; double denom = 0.0; - for (int u = 0; u < users; ++u) { - for (int p = 0; p < products; ++p) { - double prediction = predictedRatings.get(u, p); - double truePref = truePrefs.get(u, p); - double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p)); - double err = confidence * (truePref - prediction) * (truePref - prediction); - sqErr += err; - denom += confidence; - } + for (Rating r: predictedRatings) { + double prediction = r.rating(); + double truePref = truePrefs.get(r.user(), r.product()); + double confidence = 1.0 + + /* alpha = */ 1.0 * Math.abs(trueRatings.get(r.user(), r.product())); + double err = confidence * (truePref - prediction) * (truePref - prediction); + sqErr += err; + denom += confidence; } double rmse = Math.sqrt(sqErr / denom); Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f", - rmse, matchThreshold), rmse < matchThreshold); + rmse, matchThreshold), rmse < matchThreshold); } } @@ -116,7 +104,7 @@ public class JavaALSSuite implements Serializable { JavaRDD<Rating> data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); - validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3()); } @Test @@ -132,8 +120,8 @@ public class JavaALSSuite implements Serializable { MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) - .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); + .run(data); + validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3()); } @Test @@ -147,7 +135,7 @@ public class JavaALSSuite implements Serializable { JavaRDD<Rating> data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); - validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @Test @@ -165,7 +153,7 @@ public class JavaALSSuite implements Serializable { .setIterations(iterations) .setImplicitPrefs(true) .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @Test @@ -183,7 +171,7 @@ public class JavaALSSuite implements Serializable { .setImplicitPrefs(true) .setSeed(8675309L) .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @Test |