aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatei Zaharia <matei@eecs.berkeley.edu>2013-10-08 23:44:55 -0700
committerMatei Zaharia <matei@eecs.berkeley.edu>2013-10-08 23:44:55 -0700
commit3218fa795ff3ddee855772184aebe99098701d4f (patch)
tree7fd15f20e789ca5e63c79335e6c4bc36c3e90ff6
parente67d5b962a2adddc073cfc9c99be9012fbb69838 (diff)
parenta5e58b8f980fe49d02dce83e366f0ea0cf070d76 (diff)
downloadspark-3218fa795ff3ddee855772184aebe99098701d4f.tar.gz
spark-3218fa795ff3ddee855772184aebe99098701d4f.tar.bz2
spark-3218fa795ff3ddee855772184aebe99098701d4f.zip
Merge pull request #4 from MLnick/implicit-als
Adding algorithm for implicit feedback data to ALS This PR adds the commonly used "implicit feedack" variant to ALS. The implementation is based in part on Mahout's implementation, which is in turn based on [Collaborative Filtering for Implicit Feedback Datasets](http://research.yahoo.com/pub/2433). It has been adapted for the blocked approach used in MLlib. I have tested this implementation against the MovieLens 100k, 1m and 10m datasets, and confirmed that it produces the same RMSE score as Mahout, as well as my own port of Mahout's implicit ALS implementation to Spark (not that RMSE is necessarily the best metric to judge by for implicit feedback, but it provides a consistent metric for comparison). It turned out to be more straightforward than I had thought to add this. The main additions are: 1. Adding `implicitPrefs` boolean flag and `alpha` parameter 2. Added the `computeYtY` method. In each least-squares step, the algorithm requires the computation of `YtY`, where `Y` is the {user, item} factor matrix. Since the factors are already block-distributed in an `RDD`, this is quite straightforward to compute but does add an extra operation over the explicit version (but only twice per iteration) 3. Finally the actual solve step in `updateBlock` boils down to: * a multiplication of the `XtX` matrix by `alpha * rating` * a multiplication of the `Xty` vector by `1 + alpha * rating` * when solving for the factor vector, the implicit variant adds the `YtY` matrix to the LHS 4. Added `trainImplicit` methods in the `ALS` object 5. Added test cases for both Scala and Java - based on achieving a confidence-weighted RMSE score < 0.4 (this is taken from Mahout's test cases) It would be great to get some feedback on this and have people test things out against some datasets (MovieLens and others and perhaps proprietary datasets) both locally and on a cluster if possible. I have not yet tested on a cluster but will try to do that soon. I have tried to make things as efficient as possible but if there are potential improvements let me know. The results of a run against ml-1m are below (note the vanilla RMSE scores will be very different from the explicit variant): **INPUTS** ``` iterations=10 factors=10 lambda=0.01 alpha=1 implicitPrefs=true ``` **RESULTS** ``` Spark MLlib 0.8.0-SNAPSHOT RMSE = 3.1544 Time: 24.834 sec ``` ``` My own port of Mahout's ALS to Spark (updated to 0.8.0-SNAPSHOT) RMSE = 3.1543 Time: 58.708 sec ``` ``` Mahout 0.8 time ./factorize-movielens-1M.sh /path/to/ratings/ml-1m/ratings.dat real 3m48.648s user 6m39.254s sys 0m14.505s RMSE = 3.1539 ``` Results of a run against ml-10m ``` Spark MLlib RMSE = 3.1200 Time: 162.348 sec ``` ``` Mahout 0.8 real 23m2.220s user 43m39.185s sys 0m25.316s RMSE = 3.1187 ```
-rw-r--r--docs/mllib-guide.md24
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala199
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java85
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala75
4 files changed, 320 insertions, 63 deletions
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md
index f991d86c8d..c1ff9c417c 100644
--- a/docs/mllib-guide.md
+++ b/docs/mllib-guide.md
@@ -144,10 +144,9 @@ Available algorithms for clustering:
# Collaborative Filtering
-[Collaborative
-filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering)
+[Collaborative filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering)
is commonly used for recommender systems. These techniques aim to fill in the
-missing entries of a user-product association matrix. MLlib currently supports
+missing entries of a user-item association matrix. MLlib currently supports
model-based collaborative filtering, in which users and products are described
by a small set of latent factors that can be used to predict missing entries.
In particular, we implement the [alternating least squares
@@ -158,7 +157,24 @@ following parameters:
* *numBlocks* is the number of blacks used to parallelize computation (set to -1 to auto-configure).
* *rank* is the number of latent factors in our model.
* *iterations* is the number of iterations to run.
-* *lambda* specifies the regularization parameter in ALS.
+* *lambda* specifies the regularization parameter in ALS.
+* *implicitPrefs* specifies whether to use the *explicit feedback* ALS variant or one adapted for *implicit feedback* data
+* *alpha* is a parameter applicable to the implicit feedback variant of ALS that governs the *baseline* confidence in preference observations
+
+## Explicit vs Implicit Feedback
+
+The standard approach to matrix factorization based collaborative filtering treats
+the entries in the user-item matrix as *explicit* preferences given by the user to the item.
+
+It is common in many real-world use cases to only have access to *implicit feedback*
+(e.g. views, clicks, purchases, likes, shares etc.). The approach used in MLlib to deal with
+such data is taken from
+[Collaborative Filtering for Implicit Feedback Datasets](http://research.yahoo.com/pub/2433).
+Essentially instead of trying to model the matrix of ratings directly, this approach treats the data as
+a combination of binary preferences and *confidence values*. The ratings are then related
+to the level of confidence in observed user preferences, rather than explicit ratings given to items.
+The model then tries to find latent factors that can be used to predict the expected preference of a user
+for an item.
Available algorithms for collaborative filtering:
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index be002d02bc..36853acab5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -21,7 +21,8 @@ import scala.collection.mutable.{ArrayBuffer, BitSet}
import scala.util.Random
import scala.util.Sorting
-import org.apache.spark.{HashPartitioner, Partitioner, SparkContext}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.KryoRegistrator
@@ -61,6 +62,12 @@ case class Rating(val user: Int, val product: Int, val rating: Double)
/**
* Alternating Least Squares matrix factorization.
*
+ * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices,
+ * `X` and `Y`, i.e. `Xt * Y = R`. Typically these approximations are called 'factor' matrices.
+ * The general approach is iterative. During each iteration, one of the factor matrices is held
+ * constant, while the other is solved for using least squares. The newly-solved factor matrix is
+ * then held constant while solving for the other factor matrix.
+ *
* This is a blocked implementation of the ALS factorization algorithm that groups the two sets
* of factors (referred to as "users" and "products") into blocks and reduces communication by only
* sending one copy of each user vector to each product block on each iteration, and only for the
@@ -70,11 +77,21 @@ case class Rating(val user: Int, val product: Int, val rating: Double)
* vectors it receives from each user block it will depend on). This allows us to send only an
* array of feature vectors between each user block and product block, and have the product block
* find the users' ratings and update the products based on these messages.
+ *
+ * For implicit preference data, the algorithm used is based on
+ * "Collaborative Filtering for Implicit Feedback Datasets", available at
+ * [[http://research.yahoo.com/pub/2433]], adapted for the blocked approach used here.
+ *
+ * Essentially instead of finding the low-rank approximations to the rating matrix `R`,
+ * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if r > 0
+ * and 0 if r = 0. The ratings then act as 'confidence' values related to strength of indicated user
+ * preferences rather than explicit ratings given to items.
*/
-class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double)
- extends Serializable
+class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double,
+ var implicitPrefs: Boolean, var alpha: Double)
+ extends Serializable with Logging
{
- def this() = this(-1, 10, 10, 0.01)
+ def this() = this(-1, 10, 10, 0.01, false, 1.0)
/**
* Set the number of blocks to parallelize the computation into; pass -1 for an auto-configured
@@ -103,6 +120,16 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
this
}
+ def setImplicitPrefs(implicitPrefs: Boolean): ALS = {
+ this.implicitPrefs = implicitPrefs
+ this
+ }
+
+ def setAlpha(alpha: Double): ALS = {
+ this.alpha = alpha
+ this
+ }
+
/**
* Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
@@ -147,19 +174,24 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
}
}
- for (iter <- 0 until iterations) {
+ for (iter <- 1 to iterations) {
// perform ALS update
- products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda)
- users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda)
+ logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
+ // YtY / XtX is an Option[DoubleMatrix] and is only required for the implicit feedback model
+ val YtY = computeYtY(users)
+ val YtYb = ratings.context.broadcast(YtY)
+ products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
+ alpha, YtYb)
+ logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
+ val XtX = computeYtY(products)
+ val XtXb = ratings.context.broadcast(XtX)
+ users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
+ alpha, XtXb)
}
// Flatten and cache the two final RDDs to un-block them
- val usersOut = users.join(userOutLinks).flatMap { case (b, (factors, outLinkBlock)) =>
- for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
- }
- val productsOut = products.join(productOutLinks).flatMap { case (b, (factors, outLinkBlock)) =>
- for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
- }
+ val usersOut = unblockFactors(users, userOutLinks)
+ val productsOut = unblockFactors(products, productOutLinks)
usersOut.persist()
productsOut.persist()
@@ -168,6 +200,40 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
}
/**
+ * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors
+ * for each user (or product), in a distributed fashion. Here `reduceByKeyLocally` is used as
+ * the driver program requires `YtY` to broadcast it to the slaves
+ * @param factors the (block-distributed) user or product factor vectors
+ * @return Option[YtY] - whose value is only used in the implicit preference model
+ */
+ def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = {
+ if (implicitPrefs) {
+ Option(
+ factors.flatMapValues{ case factorArray =>
+ factorArray.map{ vector =>
+ val x = new DoubleMatrix(vector)
+ x.mmul(x.transpose())
+ }
+ }.reduceByKeyLocally((a, b) => a.addi(b))
+ .values
+ .reduce((a, b) => a.addi(b))
+ )
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs
+ */
+ def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])],
+ outLinks: RDD[(Int, OutLinkBlock)]) = {
+ blockedFactors.join(outLinks).flatMap{ case (b, (factors, outLinkBlock)) =>
+ for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
+ }
+ }
+
+ /**
* Make the out-links table for a block of the users (or products) dataset given the list of
* (user, product, rating) values for the users in that block (or the opposite for products).
*/
@@ -251,7 +317,9 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
userInLinks: RDD[(Int, InLinkBlock)],
partitioner: Partitioner,
rank: Int,
- lambda: Double)
+ lambda: Double,
+ alpha: Double,
+ YtY: Broadcast[Option[DoubleMatrix]])
: RDD[(Int, Array[Array[Double]])] =
{
val numBlocks = products.partitions.size
@@ -265,7 +333,9 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) }
}.groupByKey(partitioner)
.join(userInLinks)
- .mapValues{ case (messages, inLinkBlock) => updateBlock(messages, inLinkBlock, rank, lambda) }
+ .mapValues{ case (messages, inLinkBlock) =>
+ updateBlock(messages, inLinkBlock, rank, lambda, alpha, YtY)
+ }
}
/**
@@ -273,7 +343,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
* it received from each product and its InLinkBlock.
*/
def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock,
- rank: Int, lambda: Double)
+ rank: Int, lambda: Double, alpha: Double, YtY: Broadcast[Option[DoubleMatrix]])
: Array[Array[Double]] =
{
// Sort the incoming block factor messages by block ID and make them an array
@@ -298,8 +368,14 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
fillXtX(x, tempXtX)
val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p)
for (i <- 0 until us.length) {
- userXtX(us(i)).addi(tempXtX)
- SimpleBlas.axpy(rs(i), x, userXy(us(i)))
+ implicitPrefs match {
+ case false =>
+ userXtX(us(i)).addi(tempXtX)
+ SimpleBlas.axpy(rs(i), x, userXy(us(i)))
+ case true =>
+ userXtX(us(i)).addi(tempXtX.mul(alpha * rs(i)))
+ SimpleBlas.axpy(1 + alpha * rs(i), x, userXy(us(i)))
+ }
}
}
}
@@ -311,7 +387,10 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
// Add regularization
(0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda)
// Solve the resulting matrix, which is symmetric and positive-definite
- Solve.solvePositive(fullXtX, userXy(index)).data
+ implicitPrefs match {
+ case false => Solve.solvePositive(fullXtX, userXy(index)).data
+ case true => Solve.solvePositive(fullXtX.add(YtY.value.get), userXy(index)).data
+ }
}
}
@@ -381,7 +460,7 @@ object ALS {
blocks: Int)
: MatrixFactorizationModel =
{
- new ALS(blocks, rank, iterations, lambda).run(ratings)
+ new ALS(blocks, rank, iterations, lambda, false, 1.0).run(ratings)
}
/**
@@ -419,6 +498,68 @@ object ALS {
train(ratings, rank, iterations, 0.01, -1)
}
+ /**
+ * Train a matrix factorization model given an RDD of 'implicit preferences' given by users
+ * to some products, in the form of (userID, productID, preference) pairs. We approximate the
+ * ratings matrix as the product of two lower-rank matrices of a given rank (number of features).
+ * To solve for these features, we run a given number of iterations of ALS. This is done using
+ * a level of parallelism given by `blocks`.
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ * @param lambda regularization factor (recommended: 0.01)
+ * @param blocks level of parallelism to split computation into
+ * @param alpha confidence parameter (only applies when immplicitPrefs = true)
+ */
+ def trainImplicit(
+ ratings: RDD[Rating],
+ rank: Int,
+ iterations: Int,
+ lambda: Double,
+ blocks: Int,
+ alpha: Double)
+ : MatrixFactorizationModel =
+ {
+ new ALS(blocks, rank, iterations, lambda, true, alpha).run(ratings)
+ }
+
+ /**
+ * Train a matrix factorization model given an RDD of 'implicit preferences' given by users to
+ * some products, in the form of (userID, productID, preference) pairs. We approximate the
+ * ratings matrix as the product of two lower-rank matrices of a given rank (number of features).
+ * To solve for these features, we run a given number of iterations of ALS. The level of
+ * parallelism is determined automatically based on the number of partitions in `ratings`.
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ * @param lambda regularization factor (recommended: 0.01)
+ */
+ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double)
+ : MatrixFactorizationModel =
+ {
+ trainImplicit(ratings, rank, iterations, lambda, -1, alpha)
+ }
+
+ /**
+ * Train a matrix factorization model given an RDD of 'implicit preferences' ratings given by
+ * users to some products, in the form of (userID, productID, rating) pairs. We approximate the
+ * ratings matrix as the product of two lower-rank matrices of a given rank (number of features).
+ * To solve for these features, we run a given number of iterations of ALS. The level of
+ * parallelism is determined automatically based on the number of partitions in `ratings`.
+ * Model parameters `alpha` and `lambda` are set to reasonable default values
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ */
+ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int)
+ : MatrixFactorizationModel =
+ {
+ trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0)
+ }
+
private class ALSRegistrator extends KryoRegistrator {
override def registerClasses(kryo: Kryo) {
kryo.register(classOf[Rating])
@@ -426,29 +567,37 @@ object ALS {
}
def main(args: Array[String]) {
- if (args.length != 5 && args.length != 6) {
- println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]")
+ if (args.length < 5 || args.length > 9) {
+ println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> " +
+ "[<lambda>] [<implicitPrefs>] [<alpha>] [<blocks>]")
System.exit(1)
}
val (master, ratingsFile, rank, iters, outputDir) =
(args(0), args(1), args(2).toInt, args(3).toInt, args(4))
- val blocks = if (args.length == 6) args(5).toInt else -1
+ val lambda = if (args.length >= 6) args(5).toDouble else 0.01
+ val implicitPrefs = if (args.length >= 7) args(6).toBoolean else false
+ val alpha = if (args.length >= 8) args(7).toDouble else 1
+ val blocks = if (args.length == 9) args(8).toInt else -1
+
System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName)
System.setProperty("spark.kryo.referenceTracking", "false")
System.setProperty("spark.kryoserializer.buffer.mb", "8")
System.setProperty("spark.locality.wait", "10000")
+
val sc = new SparkContext(master, "ALS")
val ratings = sc.textFile(ratingsFile).map { line =>
val fields = line.split(',')
Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
}
- val model = ALS.train(ratings, rank, iters, 0.01, blocks)
+ val model = new ALS(rank = rank, iterations = iters, lambda = lambda,
+ numBlocks = blocks, implicitPrefs = implicitPrefs, alpha = alpha).run(ratings)
+
model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
.saveAsTextFile(outputDir + "/userFeatures")
model.productFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
.saveAsTextFile(outputDir + "/productFeatures")
println("Final user/product features written to " + outputDir)
- System.exit(0)
+ sc.stop()
}
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index 3323f6cee2..eafee060cd 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.recommendation;
import java.io.Serializable;
import java.util.List;
+import java.lang.Math;
import scala.Tuple2;
@@ -48,7 +49,7 @@ public class JavaALSSuite implements Serializable {
}
void validatePrediction(MatrixFactorizationModel model, int users, int products, int features,
- DoubleMatrix trueRatings, double matchThreshold) {
+ DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) {
DoubleMatrix predictedU = new DoubleMatrix(users, features);
List<scala.Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
for (int i = 0; i < features; ++i) {
@@ -68,12 +69,32 @@ public class JavaALSSuite implements Serializable {
DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose());
- for (int u = 0; u < users; ++u) {
- for (int p = 0; p < products; ++p) {
- double prediction = predictedRatings.get(u, p);
- double correct = trueRatings.get(u, p);
- Assert.assertTrue(Math.abs(prediction - correct) < matchThreshold);
+ if (!implicitPrefs) {
+ for (int u = 0; u < users; ++u) {
+ for (int p = 0; p < products; ++p) {
+ double prediction = predictedRatings.get(u, p);
+ double correct = trueRatings.get(u, p);
+ Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
+ prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
+ }
}
+ } else {
+ // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's implicit ALS tests)
+ double sqErr = 0.0;
+ double denom = 0.0;
+ for (int u = 0; u < users; ++u) {
+ for (int p = 0; p < products; ++p) {
+ double prediction = predictedRatings.get(u, p);
+ double truePref = truePrefs.get(u, p);
+ double confidence = 1.0 + /* alpha = */ 1.0 * trueRatings.get(u, p);
+ double err = confidence * (truePref - prediction) * (truePref - prediction);
+ sqErr += err;
+ denom += 1.0;
+ }
+ }
+ double rmse = Math.sqrt(sqErr / denom);
+ Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
+ rmse, matchThreshold), Math.abs(rmse) < matchThreshold);
}
}
@@ -81,30 +102,62 @@ public class JavaALSSuite implements Serializable {
public void runALSUsingStaticMethods() {
int features = 1;
int iterations = 15;
- int users = 10;
- int products = 10;
- scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
- users, products, features, 0.7);
+ int users = 50;
+ int products = 100;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
- validatePrediction(model, users, products, features, testData._2(), 0.3);
+ validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
}
@Test
public void runALSUsingConstructor() {
int features = 2;
int iterations = 15;
- int users = 20;
- int products = 30;
- scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
- users, products, features, 0.7);
+ int users = 100;
+ int products = 200;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
.run(data.rdd());
- validatePrediction(model, users, products, features, testData._2(), 0.3);
+ validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
+ }
+
+ @Test
+ public void runImplicitALSUsingStaticMethods() {
+ int features = 1;
+ int iterations = 15;
+ int users = 80;
+ int products = 160;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, true);
+
+ JavaRDD<Rating> data = sc.parallelize(testData._1());
+ MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
+ validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+ }
+
+ @Test
+ public void runImplicitALSUsingConstructor() {
+ int features = 2;
+ int iterations = 15;
+ int users = 100;
+ int products = 200;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, true);
+
+ JavaRDD<Rating> data = sc.parallelize(testData._1());
+
+ MatrixFactorizationModel model = new ALS().setRank(features)
+ .setIterations(iterations)
+ .setImplicitPrefs(true)
+ .run(data.rdd());
+ validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 347ef238f4..fafc5ec5f2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -34,16 +34,19 @@ object ALSSuite {
users: Int,
products: Int,
features: Int,
- samplingRate: Double): (java.util.List[Rating], DoubleMatrix) = {
- val (sampledRatings, trueRatings) = generateRatings(users, products, features, samplingRate)
- (seqAsJavaList(sampledRatings), trueRatings)
+ samplingRate: Double,
+ implicitPrefs: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = {
+ val (sampledRatings, trueRatings, truePrefs) =
+ generateRatings(users, products, features, samplingRate, implicitPrefs)
+ (seqAsJavaList(sampledRatings), trueRatings, truePrefs)
}
def generateRatings(
users: Int,
products: Int,
features: Int,
- samplingRate: Double): (Seq[Rating], DoubleMatrix) = {
+ samplingRate: Double,
+ implicitPrefs: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = {
val rand = new Random(42)
// Create a random matrix with uniform values from -1 to 1
@@ -52,14 +55,20 @@ object ALSSuite {
val userMatrix = randomMatrix(users, features)
val productMatrix = randomMatrix(features, products)
- val trueRatings = userMatrix.mmul(productMatrix)
+ val (trueRatings, truePrefs) = implicitPrefs match {
+ case true =>
+ val raw = new DoubleMatrix(users, products, Array.fill(users * products)(rand.nextInt(10).toDouble): _*)
+ val prefs = new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*)
+ (raw, prefs)
+ case false => (userMatrix.mmul(productMatrix), null)
+ }
val sampledRatings = {
for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate)
yield Rating(u, p, trueRatings.get(u, p))
}
- (sampledRatings, trueRatings)
+ (sampledRatings, trueRatings, truePrefs)
}
}
@@ -78,11 +87,19 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
}
test("rank-1 matrices") {
- testALS(10, 20, 1, 15, 0.7, 0.3)
+ testALS(50, 100, 1, 15, 0.7, 0.3)
}
test("rank-2 matrices") {
- testALS(20, 30, 2, 15, 0.7, 0.3)
+ testALS(100, 200, 2, 15, 0.7, 0.3)
+ }
+
+ test("rank-1 matrices implicit") {
+ testALS(80, 160, 1, 15, 0.7, 0.4, true)
+ }
+
+ test("rank-2 matrices implicit") {
+ testALS(100, 200, 2, 15, 0.7, 0.4, true)
}
/**
@@ -96,11 +113,14 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
* @param matchThreshold max difference allowed to consider a predicted rating correct
*/
def testALS(users: Int, products: Int, features: Int, iterations: Int,
- samplingRate: Double, matchThreshold: Double)
+ samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false)
{
- val (sampledRatings, trueRatings) = ALSSuite.generateRatings(users, products,
- features, samplingRate)
- val model = ALS.train(sc.parallelize(sampledRatings), features, iterations)
+ val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
+ features, samplingRate, implicitPrefs)
+ val model = implicitPrefs match {
+ case false => ALS.train(sc.parallelize(sampledRatings), features, iterations)
+ case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations)
+ }
val predictedU = new DoubleMatrix(users, features)
for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) {
@@ -112,12 +132,31 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
}
val predictedRatings = predictedU.mmul(predictedP.transpose)
- for (u <- 0 until users; p <- 0 until products) {
- val prediction = predictedRatings.get(u, p)
- val correct = trueRatings.get(u, p)
- if (math.abs(prediction - correct) > matchThreshold) {
- fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
- u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP))
+ if (!implicitPrefs) {
+ for (u <- 0 until users; p <- 0 until products) {
+ val prediction = predictedRatings.get(u, p)
+ val correct = trueRatings.get(u, p)
+ if (math.abs(prediction - correct) > matchThreshold) {
+ fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
+ u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP))
+ }
+ }
+ } else {
+ // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's tests)
+ var sqErr = 0.0
+ var denom = 0.0
+ for (u <- 0 until users; p <- 0 until products) {
+ val prediction = predictedRatings.get(u, p)
+ val truePref = truePrefs.get(u, p)
+ val confidence = 1 + 1.0 * trueRatings.get(u, p)
+ val err = confidence * (truePref - prediction) * (truePref - prediction)
+ sqErr += err
+ denom += 1
+ }
+ val rmse = math.sqrt(sqErr / denom)
+ if (math.abs(rmse) > matchThreshold) {
+ fail("Model failed to predict RMSE: %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
+ rmse, truePrefs, predictedRatings, predictedU, predictedP))
}
}
}