aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorPrashant Sharma <prashant.s@imaginea.com>2013-10-10 09:42:55 +0530
committerPrashant Sharma <prashant.s@imaginea.com>2013-10-10 09:42:55 +0530
commit026ab7566167e6c8ab1b0cce75b9e09bbd485bee (patch)
treea713bacba391eb9b8e07ca0d2f6521cd2b061b49 /mllib
parent26860639c5fee7fc23db1e686f8eb202921e4314 (diff)
parent320418f7c8b42d4ce781b32c9ee47a9b54550b89 (diff)
downloadspark-026ab7566167e6c8ab1b0cce75b9e09bbd485bee.tar.gz
spark-026ab7566167e6c8ab1b0cce75b9e09bbd485bee.tar.bz2
spark-026ab7566167e6c8ab1b0cce75b9e09bbd485bee.zip
Merge branch 'master' of github.com:apache/incubator-spark into scala-2.10
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala199
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java85
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala75
3 files changed, 300 insertions, 59 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index be002d02bc..36853acab5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -21,7 +21,8 @@ import scala.collection.mutable.{ArrayBuffer, BitSet}
import scala.util.Random
import scala.util.Sorting
-import org.apache.spark.{HashPartitioner, Partitioner, SparkContext}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.KryoRegistrator
@@ -61,6 +62,12 @@ case class Rating(val user: Int, val product: Int, val rating: Double)
/**
* Alternating Least Squares matrix factorization.
*
+ * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices,
+ * `X` and `Y`, i.e. `Xt * Y = R`. Typically these approximations are called 'factor' matrices.
+ * The general approach is iterative. During each iteration, one of the factor matrices is held
+ * constant, while the other is solved for using least squares. The newly-solved factor matrix is
+ * then held constant while solving for the other factor matrix.
+ *
* This is a blocked implementation of the ALS factorization algorithm that groups the two sets
* of factors (referred to as "users" and "products") into blocks and reduces communication by only
* sending one copy of each user vector to each product block on each iteration, and only for the
@@ -70,11 +77,21 @@ case class Rating(val user: Int, val product: Int, val rating: Double)
* vectors it receives from each user block it will depend on). This allows us to send only an
* array of feature vectors between each user block and product block, and have the product block
* find the users' ratings and update the products based on these messages.
+ *
+ * For implicit preference data, the algorithm used is based on
+ * "Collaborative Filtering for Implicit Feedback Datasets", available at
+ * [[http://research.yahoo.com/pub/2433]], adapted for the blocked approach used here.
+ *
+ * Essentially instead of finding the low-rank approximations to the rating matrix `R`,
+ * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if r > 0
+ * and 0 if r = 0. The ratings then act as 'confidence' values related to strength of indicated user
+ * preferences rather than explicit ratings given to items.
*/
-class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double)
- extends Serializable
+class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double,
+ var implicitPrefs: Boolean, var alpha: Double)
+ extends Serializable with Logging
{
- def this() = this(-1, 10, 10, 0.01)
+ def this() = this(-1, 10, 10, 0.01, false, 1.0)
/**
* Set the number of blocks to parallelize the computation into; pass -1 for an auto-configured
@@ -103,6 +120,16 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
this
}
+ def setImplicitPrefs(implicitPrefs: Boolean): ALS = {
+ this.implicitPrefs = implicitPrefs
+ this
+ }
+
+ def setAlpha(alpha: Double): ALS = {
+ this.alpha = alpha
+ this
+ }
+
/**
* Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
@@ -147,19 +174,24 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
}
}
- for (iter <- 0 until iterations) {
+ for (iter <- 1 to iterations) {
// perform ALS update
- products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda)
- users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda)
+ logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
+ // YtY / XtX is an Option[DoubleMatrix] and is only required for the implicit feedback model
+ val YtY = computeYtY(users)
+ val YtYb = ratings.context.broadcast(YtY)
+ products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
+ alpha, YtYb)
+ logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
+ val XtX = computeYtY(products)
+ val XtXb = ratings.context.broadcast(XtX)
+ users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
+ alpha, XtXb)
}
// Flatten and cache the two final RDDs to un-block them
- val usersOut = users.join(userOutLinks).flatMap { case (b, (factors, outLinkBlock)) =>
- for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
- }
- val productsOut = products.join(productOutLinks).flatMap { case (b, (factors, outLinkBlock)) =>
- for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
- }
+ val usersOut = unblockFactors(users, userOutLinks)
+ val productsOut = unblockFactors(products, productOutLinks)
usersOut.persist()
productsOut.persist()
@@ -168,6 +200,40 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
}
/**
+ * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors
+ * for each user (or product), in a distributed fashion. Here `reduceByKeyLocally` is used as
+ * the driver program requires `YtY` to broadcast it to the slaves
+ * @param factors the (block-distributed) user or product factor vectors
+ * @return Option[YtY] - whose value is only used in the implicit preference model
+ */
+ def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = {
+ if (implicitPrefs) {
+ Option(
+ factors.flatMapValues{ case factorArray =>
+ factorArray.map{ vector =>
+ val x = new DoubleMatrix(vector)
+ x.mmul(x.transpose())
+ }
+ }.reduceByKeyLocally((a, b) => a.addi(b))
+ .values
+ .reduce((a, b) => a.addi(b))
+ )
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs
+ */
+ def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])],
+ outLinks: RDD[(Int, OutLinkBlock)]) = {
+ blockedFactors.join(outLinks).flatMap{ case (b, (factors, outLinkBlock)) =>
+ for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
+ }
+ }
+
+ /**
* Make the out-links table for a block of the users (or products) dataset given the list of
* (user, product, rating) values for the users in that block (or the opposite for products).
*/
@@ -251,7 +317,9 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
userInLinks: RDD[(Int, InLinkBlock)],
partitioner: Partitioner,
rank: Int,
- lambda: Double)
+ lambda: Double,
+ alpha: Double,
+ YtY: Broadcast[Option[DoubleMatrix]])
: RDD[(Int, Array[Array[Double]])] =
{
val numBlocks = products.partitions.size
@@ -265,7 +333,9 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) }
}.groupByKey(partitioner)
.join(userInLinks)
- .mapValues{ case (messages, inLinkBlock) => updateBlock(messages, inLinkBlock, rank, lambda) }
+ .mapValues{ case (messages, inLinkBlock) =>
+ updateBlock(messages, inLinkBlock, rank, lambda, alpha, YtY)
+ }
}
/**
@@ -273,7 +343,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
* it received from each product and its InLinkBlock.
*/
def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock,
- rank: Int, lambda: Double)
+ rank: Int, lambda: Double, alpha: Double, YtY: Broadcast[Option[DoubleMatrix]])
: Array[Array[Double]] =
{
// Sort the incoming block factor messages by block ID and make them an array
@@ -298,8 +368,14 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
fillXtX(x, tempXtX)
val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p)
for (i <- 0 until us.length) {
- userXtX(us(i)).addi(tempXtX)
- SimpleBlas.axpy(rs(i), x, userXy(us(i)))
+ implicitPrefs match {
+ case false =>
+ userXtX(us(i)).addi(tempXtX)
+ SimpleBlas.axpy(rs(i), x, userXy(us(i)))
+ case true =>
+ userXtX(us(i)).addi(tempXtX.mul(alpha * rs(i)))
+ SimpleBlas.axpy(1 + alpha * rs(i), x, userXy(us(i)))
+ }
}
}
}
@@ -311,7 +387,10 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
// Add regularization
(0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda)
// Solve the resulting matrix, which is symmetric and positive-definite
- Solve.solvePositive(fullXtX, userXy(index)).data
+ implicitPrefs match {
+ case false => Solve.solvePositive(fullXtX, userXy(index)).data
+ case true => Solve.solvePositive(fullXtX.add(YtY.value.get), userXy(index)).data
+ }
}
}
@@ -381,7 +460,7 @@ object ALS {
blocks: Int)
: MatrixFactorizationModel =
{
- new ALS(blocks, rank, iterations, lambda).run(ratings)
+ new ALS(blocks, rank, iterations, lambda, false, 1.0).run(ratings)
}
/**
@@ -419,6 +498,68 @@ object ALS {
train(ratings, rank, iterations, 0.01, -1)
}
+ /**
+ * Train a matrix factorization model given an RDD of 'implicit preferences' given by users
+ * to some products, in the form of (userID, productID, preference) pairs. We approximate the
+ * ratings matrix as the product of two lower-rank matrices of a given rank (number of features).
+ * To solve for these features, we run a given number of iterations of ALS. This is done using
+ * a level of parallelism given by `blocks`.
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ * @param lambda regularization factor (recommended: 0.01)
+ * @param blocks level of parallelism to split computation into
+ * @param alpha confidence parameter (only applies when immplicitPrefs = true)
+ */
+ def trainImplicit(
+ ratings: RDD[Rating],
+ rank: Int,
+ iterations: Int,
+ lambda: Double,
+ blocks: Int,
+ alpha: Double)
+ : MatrixFactorizationModel =
+ {
+ new ALS(blocks, rank, iterations, lambda, true, alpha).run(ratings)
+ }
+
+ /**
+ * Train a matrix factorization model given an RDD of 'implicit preferences' given by users to
+ * some products, in the form of (userID, productID, preference) pairs. We approximate the
+ * ratings matrix as the product of two lower-rank matrices of a given rank (number of features).
+ * To solve for these features, we run a given number of iterations of ALS. The level of
+ * parallelism is determined automatically based on the number of partitions in `ratings`.
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ * @param lambda regularization factor (recommended: 0.01)
+ */
+ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double)
+ : MatrixFactorizationModel =
+ {
+ trainImplicit(ratings, rank, iterations, lambda, -1, alpha)
+ }
+
+ /**
+ * Train a matrix factorization model given an RDD of 'implicit preferences' ratings given by
+ * users to some products, in the form of (userID, productID, rating) pairs. We approximate the
+ * ratings matrix as the product of two lower-rank matrices of a given rank (number of features).
+ * To solve for these features, we run a given number of iterations of ALS. The level of
+ * parallelism is determined automatically based on the number of partitions in `ratings`.
+ * Model parameters `alpha` and `lambda` are set to reasonable default values
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ */
+ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int)
+ : MatrixFactorizationModel =
+ {
+ trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0)
+ }
+
private class ALSRegistrator extends KryoRegistrator {
override def registerClasses(kryo: Kryo) {
kryo.register(classOf[Rating])
@@ -426,29 +567,37 @@ object ALS {
}
def main(args: Array[String]) {
- if (args.length != 5 && args.length != 6) {
- println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]")
+ if (args.length < 5 || args.length > 9) {
+ println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> " +
+ "[<lambda>] [<implicitPrefs>] [<alpha>] [<blocks>]")
System.exit(1)
}
val (master, ratingsFile, rank, iters, outputDir) =
(args(0), args(1), args(2).toInt, args(3).toInt, args(4))
- val blocks = if (args.length == 6) args(5).toInt else -1
+ val lambda = if (args.length >= 6) args(5).toDouble else 0.01
+ val implicitPrefs = if (args.length >= 7) args(6).toBoolean else false
+ val alpha = if (args.length >= 8) args(7).toDouble else 1
+ val blocks = if (args.length == 9) args(8).toInt else -1
+
System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName)
System.setProperty("spark.kryo.referenceTracking", "false")
System.setProperty("spark.kryoserializer.buffer.mb", "8")
System.setProperty("spark.locality.wait", "10000")
+
val sc = new SparkContext(master, "ALS")
val ratings = sc.textFile(ratingsFile).map { line =>
val fields = line.split(',')
Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
}
- val model = ALS.train(ratings, rank, iters, 0.01, blocks)
+ val model = new ALS(rank = rank, iterations = iters, lambda = lambda,
+ numBlocks = blocks, implicitPrefs = implicitPrefs, alpha = alpha).run(ratings)
+
model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
.saveAsTextFile(outputDir + "/userFeatures")
model.productFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
.saveAsTextFile(outputDir + "/productFeatures")
println("Final user/product features written to " + outputDir)
- System.exit(0)
+ sc.stop()
}
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index c474e01188..b40f552e0d 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.recommendation;
import java.io.Serializable;
import java.util.List;
+import java.lang.Math;
import org.junit.After;
import org.junit.Assert;
@@ -46,7 +47,7 @@ public class JavaALSSuite implements Serializable {
}
void validatePrediction(MatrixFactorizationModel model, int users, int products, int features,
- DoubleMatrix trueRatings, double matchThreshold) {
+ DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) {
DoubleMatrix predictedU = new DoubleMatrix(users, features);
List<scala.Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
for (int i = 0; i < features; ++i) {
@@ -66,12 +67,32 @@ public class JavaALSSuite implements Serializable {
DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose());
- for (int u = 0; u < users; ++u) {
- for (int p = 0; p < products; ++p) {
- double prediction = predictedRatings.get(u, p);
- double correct = trueRatings.get(u, p);
- Assert.assertTrue(Math.abs(prediction - correct) < matchThreshold);
+ if (!implicitPrefs) {
+ for (int u = 0; u < users; ++u) {
+ for (int p = 0; p < products; ++p) {
+ double prediction = predictedRatings.get(u, p);
+ double correct = trueRatings.get(u, p);
+ Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
+ prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
+ }
}
+ } else {
+ // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's implicit ALS tests)
+ double sqErr = 0.0;
+ double denom = 0.0;
+ for (int u = 0; u < users; ++u) {
+ for (int p = 0; p < products; ++p) {
+ double prediction = predictedRatings.get(u, p);
+ double truePref = truePrefs.get(u, p);
+ double confidence = 1.0 + /* alpha = */ 1.0 * trueRatings.get(u, p);
+ double err = confidence * (truePref - prediction) * (truePref - prediction);
+ sqErr += err;
+ denom += 1.0;
+ }
+ }
+ double rmse = Math.sqrt(sqErr / denom);
+ Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
+ rmse, matchThreshold), Math.abs(rmse) < matchThreshold);
}
}
@@ -79,30 +100,62 @@ public class JavaALSSuite implements Serializable {
public void runALSUsingStaticMethods() {
int features = 1;
int iterations = 15;
- int users = 10;
- int products = 10;
- scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
- users, products, features, 0.7);
+ int users = 50;
+ int products = 100;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
- validatePrediction(model, users, products, features, testData._2(), 0.3);
+ validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
}
@Test
public void runALSUsingConstructor() {
int features = 2;
int iterations = 15;
- int users = 20;
- int products = 30;
- scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
- users, products, features, 0.7);
+ int users = 100;
+ int products = 200;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
.run(data.rdd());
- validatePrediction(model, users, products, features, testData._2(), 0.3);
+ validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
+ }
+
+ @Test
+ public void runImplicitALSUsingStaticMethods() {
+ int features = 1;
+ int iterations = 15;
+ int users = 80;
+ int products = 160;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, true);
+
+ JavaRDD<Rating> data = sc.parallelize(testData._1());
+ MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
+ validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+ }
+
+ @Test
+ public void runImplicitALSUsingConstructor() {
+ int features = 2;
+ int iterations = 15;
+ int users = 100;
+ int products = 200;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, true);
+
+ JavaRDD<Rating> data = sc.parallelize(testData._1());
+
+ MatrixFactorizationModel model = new ALS().setRank(features)
+ .setIterations(iterations)
+ .setImplicitPrefs(true)
+ .run(data.rdd());
+ validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 347ef238f4..fafc5ec5f2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -34,16 +34,19 @@ object ALSSuite {
users: Int,
products: Int,
features: Int,
- samplingRate: Double): (java.util.List[Rating], DoubleMatrix) = {
- val (sampledRatings, trueRatings) = generateRatings(users, products, features, samplingRate)
- (seqAsJavaList(sampledRatings), trueRatings)
+ samplingRate: Double,
+ implicitPrefs: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = {
+ val (sampledRatings, trueRatings, truePrefs) =
+ generateRatings(users, products, features, samplingRate, implicitPrefs)
+ (seqAsJavaList(sampledRatings), trueRatings, truePrefs)
}
def generateRatings(
users: Int,
products: Int,
features: Int,
- samplingRate: Double): (Seq[Rating], DoubleMatrix) = {
+ samplingRate: Double,
+ implicitPrefs: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = {
val rand = new Random(42)
// Create a random matrix with uniform values from -1 to 1
@@ -52,14 +55,20 @@ object ALSSuite {
val userMatrix = randomMatrix(users, features)
val productMatrix = randomMatrix(features, products)
- val trueRatings = userMatrix.mmul(productMatrix)
+ val (trueRatings, truePrefs) = implicitPrefs match {
+ case true =>
+ val raw = new DoubleMatrix(users, products, Array.fill(users * products)(rand.nextInt(10).toDouble): _*)
+ val prefs = new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*)
+ (raw, prefs)
+ case false => (userMatrix.mmul(productMatrix), null)
+ }
val sampledRatings = {
for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate)
yield Rating(u, p, trueRatings.get(u, p))
}
- (sampledRatings, trueRatings)
+ (sampledRatings, trueRatings, truePrefs)
}
}
@@ -78,11 +87,19 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
}
test("rank-1 matrices") {
- testALS(10, 20, 1, 15, 0.7, 0.3)
+ testALS(50, 100, 1, 15, 0.7, 0.3)
}
test("rank-2 matrices") {
- testALS(20, 30, 2, 15, 0.7, 0.3)
+ testALS(100, 200, 2, 15, 0.7, 0.3)
+ }
+
+ test("rank-1 matrices implicit") {
+ testALS(80, 160, 1, 15, 0.7, 0.4, true)
+ }
+
+ test("rank-2 matrices implicit") {
+ testALS(100, 200, 2, 15, 0.7, 0.4, true)
}
/**
@@ -96,11 +113,14 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
* @param matchThreshold max difference allowed to consider a predicted rating correct
*/
def testALS(users: Int, products: Int, features: Int, iterations: Int,
- samplingRate: Double, matchThreshold: Double)
+ samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false)
{
- val (sampledRatings, trueRatings) = ALSSuite.generateRatings(users, products,
- features, samplingRate)
- val model = ALS.train(sc.parallelize(sampledRatings), features, iterations)
+ val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
+ features, samplingRate, implicitPrefs)
+ val model = implicitPrefs match {
+ case false => ALS.train(sc.parallelize(sampledRatings), features, iterations)
+ case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations)
+ }
val predictedU = new DoubleMatrix(users, features)
for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) {
@@ -112,12 +132,31 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
}
val predictedRatings = predictedU.mmul(predictedP.transpose)
- for (u <- 0 until users; p <- 0 until products) {
- val prediction = predictedRatings.get(u, p)
- val correct = trueRatings.get(u, p)
- if (math.abs(prediction - correct) > matchThreshold) {
- fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
- u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP))
+ if (!implicitPrefs) {
+ for (u <- 0 until users; p <- 0 until products) {
+ val prediction = predictedRatings.get(u, p)
+ val correct = trueRatings.get(u, p)
+ if (math.abs(prediction - correct) > matchThreshold) {
+ fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
+ u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP))
+ }
+ }
+ } else {
+ // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's tests)
+ var sqErr = 0.0
+ var denom = 0.0
+ for (u <- 0 until users; p <- 0 until products) {
+ val prediction = predictedRatings.get(u, p)
+ val truePref = truePrefs.get(u, p)
+ val confidence = 1 + 1.0 * trueRatings.get(u, p)
+ val err = confidence * (truePref - prediction) * (truePref - prediction)
+ sqErr += err
+ denom += 1
+ }
+ val rmse = math.sqrt(sqErr / denom)
+ if (math.abs(rmse) > matchThreshold) {
+ fail("Model failed to predict RMSE: %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
+ rmse, truePrefs, predictedRatings, predictedU, predictedP))
}
}
}