aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/java/org')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java32
1 files changed, 23 insertions, 9 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index b40f552e0d..b150334deb 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -19,7 +19,6 @@ package org.apache.spark.mllib.recommendation;
import java.io.Serializable;
import java.util.List;
-import java.lang.Math;
import org.junit.After;
import org.junit.Assert;
@@ -46,7 +45,7 @@ public class JavaALSSuite implements Serializable {
System.clearProperty("spark.driver.port");
}
- void validatePrediction(MatrixFactorizationModel model, int users, int products, int features,
+ static void validatePrediction(MatrixFactorizationModel model, int users, int products, int features,
DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) {
DoubleMatrix predictedU = new DoubleMatrix(users, features);
List<scala.Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
@@ -84,15 +83,15 @@ public class JavaALSSuite implements Serializable {
for (int p = 0; p < products; ++p) {
double prediction = predictedRatings.get(u, p);
double truePref = truePrefs.get(u, p);
- double confidence = 1.0 + /* alpha = */ 1.0 * trueRatings.get(u, p);
+ double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p));
double err = confidence * (truePref - prediction) * (truePref - prediction);
sqErr += err;
- denom += 1.0;
+ denom += confidence;
}
}
double rmse = Math.sqrt(sqErr / denom);
Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
- rmse, matchThreshold), Math.abs(rmse) < matchThreshold);
+ rmse, matchThreshold), rmse < matchThreshold);
}
}
@@ -103,7 +102,7 @@ public class JavaALSSuite implements Serializable {
int users = 50;
int products = 100;
scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
- users, products, features, 0.7, false);
+ users, products, features, 0.7, false, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
@@ -117,7 +116,7 @@ public class JavaALSSuite implements Serializable {
int users = 100;
int products = 200;
scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
- users, products, features, 0.7, false);
+ users, products, features, 0.7, false, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
@@ -134,7 +133,7 @@ public class JavaALSSuite implements Serializable {
int users = 80;
int products = 160;
scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
- users, products, features, 0.7, true);
+ users, products, features, 0.7, true, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
@@ -148,7 +147,7 @@ public class JavaALSSuite implements Serializable {
int users = 100;
int products = 200;
scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
- users, products, features, 0.7, true);
+ users, products, features, 0.7, true, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
@@ -158,4 +157,19 @@ public class JavaALSSuite implements Serializable {
.run(data.rdd());
validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
}
+
+ @Test
+ public void runImplicitALSWithNegativeWeight() {
+ int features = 2;
+ int iterations = 15;
+ int users = 80;
+ int products = 160;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, true, true);
+
+ JavaRDD<Rating> data = sc.parallelize(testData._1());
+ MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
+ validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+ }
+
}