aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-11-13 11:42:27 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-13 11:42:27 -0800
commitca26a212fda39a15fde09dfdb2fbe69580a717f6 (patch)
treee6473ef3a55f9c353642fafb4c779e8a2546d50c /mllib
parentce0333f9a008348692bb9a200449d2d992e7825e (diff)
downloadspark-ca26a212fda39a15fde09dfdb2fbe69580a717f6.tar.gz
spark-ca26a212fda39a15fde09dfdb2fbe69580a717f6.tar.bz2
spark-ca26a212fda39a15fde09dfdb2fbe69580a717f6.zip
[SPARK-4378][MLLIB] make ALS more Java-friendly
Add Java-friendly version of `run` and `predict`, and use bulk prediction in Java unit tests. The user guide update will come later (though we may not save many lines of code there). srowen Author: Xiangrui Meng <meng@databricks.com> Closes #3240 from mengxr/SPARK-4378 and squashes the following commits: 6581503 [Xiangrui Meng] check number of predictions 6c8bbd1 [Xiangrui Meng] make ALS more Java-friendly
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala17
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala15
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java74
3 files changed, 53 insertions, 53 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 84d192db53..038edc3521 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -20,20 +20,20 @@ package org.apache.spark.mllib.recommendation
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.math.{abs, sqrt}
-import scala.util.Random
-import scala.util.Sorting
+import scala.util.{Random, Sorting}
import scala.util.hashing.byteswap32
import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
+import org.apache.spark.{HashPartitioner, Logging, Partitioner}
+import org.apache.spark.SparkContext._
import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.{Logging, HashPartitioner, Partitioner}
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-import org.apache.spark.mllib.optimization.NNLS
/**
* Out-link information for a user or product block. This includes the original user/product IDs
@@ -326,6 +326,11 @@ class ALS private (
}
/**
+ * Java-friendly version of [[ALS.run]].
+ */
+ def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd)
+
+ /**
* Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors
* for each user (or product), in a distributed fashion.
*
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 66b58ba770..969e23be21 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -17,13 +17,13 @@
package org.apache.spark.mllib.recommendation
+import java.lang.{Integer => JavaInteger}
+
import org.jblas.DoubleMatrix
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.api.python.SerDe
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.rdd.RDD
/**
* Model representing the result of matrix factorization.
@@ -66,6 +66,13 @@ class MatrixFactorizationModel private[mllib] (
}
/**
+ * Java-friendly version of [[MatrixFactorizationModel.predict]].
+ */
+ def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = {
+ predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD()
+ }
+
+ /**
* Recommends products to a user.
*
* @param user the user to recommend products to
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index f6ca964322..af688c504c 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -23,13 +23,14 @@ import java.util.List;
import scala.Tuple2;
import scala.Tuple3;
+import com.google.common.collect.Lists;
import org.jblas.DoubleMatrix;
-
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
@@ -47,61 +48,48 @@ public class JavaALSSuite implements Serializable {
sc = null;
}
- static void validatePrediction(
+ void validatePrediction(
MatrixFactorizationModel model,
int users,
int products,
- int features,
DoubleMatrix trueRatings,
double matchThreshold,
boolean implicitPrefs,
DoubleMatrix truePrefs) {
- DoubleMatrix predictedU = new DoubleMatrix(users, features);
- List<Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
- for (int i = 0; i < features; ++i) {
- for (Tuple2<Object, double[]> userFeature : userFeatures) {
- predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]);
- }
- }
- DoubleMatrix predictedP = new DoubleMatrix(products, features);
-
- List<Tuple2<Object, double[]>> productFeatures =
- model.productFeatures().toJavaRDD().collect();
- for (int i = 0; i < features; ++i) {
- for (Tuple2<Object, double[]> productFeature : productFeatures) {
- predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]);
+ List<Tuple2<Integer, Integer>> localUsersProducts =
+ Lists.newArrayListWithCapacity(users * products);
+ for (int u=0; u < users; ++u) {
+ for (int p=0; p < products; ++p) {
+ localUsersProducts.add(new Tuple2<Integer, Integer>(u, p));
}
}
-
- DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose());
-
+ JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts);
+ List<Rating> predictedRatings = model.predict(usersProducts).collect();
+ Assert.assertEquals(users * products, predictedRatings.size());
if (!implicitPrefs) {
- for (int u = 0; u < users; ++u) {
- for (int p = 0; p < products; ++p) {
- double prediction = predictedRatings.get(u, p);
- double correct = trueRatings.get(u, p);
- Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
- prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
- }
+ for (Rating r: predictedRatings) {
+ double prediction = r.rating();
+ double correct = trueRatings.get(r.user(), r.product());
+ Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
+ prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
}
} else {
// For implicit prefs we use the confidence-weighted RMSE to test
// (ref Mahout's implicit ALS tests)
double sqErr = 0.0;
double denom = 0.0;
- for (int u = 0; u < users; ++u) {
- for (int p = 0; p < products; ++p) {
- double prediction = predictedRatings.get(u, p);
- double truePref = truePrefs.get(u, p);
- double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p));
- double err = confidence * (truePref - prediction) * (truePref - prediction);
- sqErr += err;
- denom += confidence;
- }
+ for (Rating r: predictedRatings) {
+ double prediction = r.rating();
+ double truePref = truePrefs.get(r.user(), r.product());
+ double confidence = 1.0 +
+ /* alpha = */ 1.0 * Math.abs(trueRatings.get(r.user(), r.product()));
+ double err = confidence * (truePref - prediction) * (truePref - prediction);
+ sqErr += err;
+ denom += confidence;
}
double rmse = Math.sqrt(sqErr / denom);
Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
- rmse, matchThreshold), rmse < matchThreshold);
+ rmse, matchThreshold), rmse < matchThreshold);
}
}
@@ -116,7 +104,7 @@ public class JavaALSSuite implements Serializable {
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
- validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
+ validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
}
@Test
@@ -132,8 +120,8 @@ public class JavaALSSuite implements Serializable {
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
- .run(data.rdd());
- validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
+ .run(data);
+ validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
}
@Test
@@ -147,7 +135,7 @@ public class JavaALSSuite implements Serializable {
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
- validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+ validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
}
@Test
@@ -165,7 +153,7 @@ public class JavaALSSuite implements Serializable {
.setIterations(iterations)
.setImplicitPrefs(true)
.run(data.rdd());
- validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+ validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
}
@Test
@@ -183,7 +171,7 @@ public class JavaALSSuite implements Serializable {
.setImplicitPrefs(true)
.setSeed(8675309L)
.run(data.rdd());
- validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+ validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
}
@Test