diff options
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index fafc5ec5f2..e683a90f57 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -90,18 +90,34 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll { testALS(50, 100, 1, 15, 0.7, 0.3) } + test("rank-1 matrices bulk") { + testALS(50, 100, 1, 15, 0.7, 0.3, false, true) + } + test("rank-2 matrices") { testALS(100, 200, 2, 15, 0.7, 0.3) } + test("rank-2 matrices bulk") { + testALS(100, 200, 2, 15, 0.7, 0.3, false, true) + } + test("rank-1 matrices implicit") { testALS(80, 160, 1, 15, 0.7, 0.4, true) } + test("rank-1 matrices implicit bulk") { + testALS(80, 160, 1, 15, 0.7, 0.4, true, true) + } + test("rank-2 matrices implicit") { testALS(100, 200, 2, 15, 0.7, 0.4, true) } + test("rank-2 matrices implicit bulk") { + testALS(100, 200, 2, 15, 0.7, 0.4, true, true) + } + /** * Test if we can correctly factorize R = U * P where U and P are of known rank. * @@ -111,9 +127,12 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll { * @param iterations number of iterations to run * @param samplingRate what fraction of the user-product pairs are known * @param matchThreshold max difference allowed to consider a predicted rating correct + * @param implicitPrefs flag to test implicit feedback + * @param bulkPredict flag to test bulk prediciton */ def testALS(users: Int, products: Int, features: Int, iterations: Int, - samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false) + samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false, + bulkPredict: Boolean = false) { val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products, features, samplingRate, implicitPrefs) @@ -130,7 +149,17 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll { for ((p, vec) <- model.productFeatures.collect(); i <- 0 until features) { predictedP.put(p, i, vec(i)) } - val predictedRatings = predictedU.mmul(predictedP.transpose) + val predictedRatings = bulkPredict match { + case false => predictedU.mmul(predictedP.transpose) + case true => + val allRatings = new DoubleMatrix(users, products) + val usersProducts = for (u <- 0 until users; p <- 0 until products) yield (u, p) + val userProductsRDD = sc.parallelize(usersProducts) + model.predict(userProductsRDD).collect().foreach { elem => + allRatings.put(elem.user, elem.product, elem.rating) + } + allRatings + } if (!implicitPrefs) { for (u <- 0 until users; p <- 0 until products) { |