aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala33
1 files changed, 31 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index fafc5ec5f2..e683a90f57 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -90,18 +90,34 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
testALS(50, 100, 1, 15, 0.7, 0.3)
}
+ test("rank-1 matrices bulk") {
+ testALS(50, 100, 1, 15, 0.7, 0.3, false, true)
+ }
+
test("rank-2 matrices") {
testALS(100, 200, 2, 15, 0.7, 0.3)
}
+ test("rank-2 matrices bulk") {
+ testALS(100, 200, 2, 15, 0.7, 0.3, false, true)
+ }
+
test("rank-1 matrices implicit") {
testALS(80, 160, 1, 15, 0.7, 0.4, true)
}
+ test("rank-1 matrices implicit bulk") {
+ testALS(80, 160, 1, 15, 0.7, 0.4, true, true)
+ }
+
test("rank-2 matrices implicit") {
testALS(100, 200, 2, 15, 0.7, 0.4, true)
}
+ test("rank-2 matrices implicit bulk") {
+ testALS(100, 200, 2, 15, 0.7, 0.4, true, true)
+ }
+
/**
* Test if we can correctly factorize R = U * P where U and P are of known rank.
*
@@ -111,9 +127,12 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
* @param iterations number of iterations to run
* @param samplingRate what fraction of the user-product pairs are known
* @param matchThreshold max difference allowed to consider a predicted rating correct
+ * @param implicitPrefs flag to test implicit feedback
+ * @param bulkPredict flag to test bulk prediciton
*/
def testALS(users: Int, products: Int, features: Int, iterations: Int,
- samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false)
+ samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false,
+ bulkPredict: Boolean = false)
{
val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
features, samplingRate, implicitPrefs)
@@ -130,7 +149,17 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
for ((p, vec) <- model.productFeatures.collect(); i <- 0 until features) {
predictedP.put(p, i, vec(i))
}
- val predictedRatings = predictedU.mmul(predictedP.transpose)
+ val predictedRatings = bulkPredict match {
+ case false => predictedU.mmul(predictedP.transpose)
+ case true =>
+ val allRatings = new DoubleMatrix(users, products)
+ val usersProducts = for (u <- 0 until users; p <- 0 until products) yield (u, p)
+ val userProductsRDD = sc.parallelize(usersProducts)
+ model.predict(userProductsRDD).collect().foreach { elem =>
+ allRatings.put(elem.user, elem.product, elem.rating)
+ }
+ allRatings
+ }
if (!implicitPrefs) {
for (u <- 0 until users; p <- 0 until products) {