aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorHossein Falaki <falaki@gmail.com>2014-01-03 15:35:20 -0800
committerHossein Falaki <falaki@gmail.com>2014-01-03 15:35:20 -0800
commit2c1cba851c2954bacf10006c0d5dad67aba77ab5 (patch)
treee52f369c5933b1dc58f3b94566ac84b6fc7d9eee /mllib
parent67f937ec222c5a7db5286c0af0ec6f9c482d2af6 (diff)
downloadspark-2c1cba851c2954bacf10006c0d5dad67aba77ab5.tar.gz
spark-2c1cba851c2954bacf10006c0d5dad67aba77ab5.tar.bz2
spark-2c1cba851c2954bacf10006c0d5dad67aba77ab5.zip
Added unit tests for bulk prediction in MatrixFactorizationModel
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala33
1 files changed, 31 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index fafc5ec5f2..e683a90f57 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -90,18 +90,34 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
testALS(50, 100, 1, 15, 0.7, 0.3)
}
+ test("rank-1 matrices bulk") {
+ testALS(50, 100, 1, 15, 0.7, 0.3, false, true)
+ }
+
test("rank-2 matrices") {
testALS(100, 200, 2, 15, 0.7, 0.3)
}
+ test("rank-2 matrices bulk") {
+ testALS(100, 200, 2, 15, 0.7, 0.3, false, true)
+ }
+
test("rank-1 matrices implicit") {
testALS(80, 160, 1, 15, 0.7, 0.4, true)
}
+ test("rank-1 matrices implicit bulk") {
+ testALS(80, 160, 1, 15, 0.7, 0.4, true, true)
+ }
+
test("rank-2 matrices implicit") {
testALS(100, 200, 2, 15, 0.7, 0.4, true)
}
+ test("rank-2 matrices implicit bulk") {
+ testALS(100, 200, 2, 15, 0.7, 0.4, true, true)
+ }
+
/**
* Test if we can correctly factorize R = U * P where U and P are of known rank.
*
@@ -111,9 +127,12 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
* @param iterations number of iterations to run
* @param samplingRate what fraction of the user-product pairs are known
* @param matchThreshold max difference allowed to consider a predicted rating correct
+ * @param implicitPrefs flag to test implicit feedback
+ * @param bulkPredict flag to test bulk prediciton
*/
def testALS(users: Int, products: Int, features: Int, iterations: Int,
- samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false)
+ samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false,
+ bulkPredict: Boolean = false)
{
val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
features, samplingRate, implicitPrefs)
@@ -130,7 +149,17 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
for ((p, vec) <- model.productFeatures.collect(); i <- 0 until features) {
predictedP.put(p, i, vec(i))
}
- val predictedRatings = predictedU.mmul(predictedP.transpose)
+ val predictedRatings = bulkPredict match {
+ case false => predictedU.mmul(predictedP.transpose)
+ case true =>
+ val allRatings = new DoubleMatrix(users, products)
+ val usersProducts = for (u <- 0 until users; p <- 0 until products) yield (u, p)
+ val userProductsRDD = sc.parallelize(usersProducts)
+ model.predict(userProductsRDD).collect().foreach { elem =>
+ allRatings.put(elem.user, elem.product, elem.rating)
+ }
+ allRatings
+ }
if (!implicitPrefs) {
for (u <- 0 until users; p <- 0 until products) {