diff options
author | Hossein Falaki <falaki@gmail.com> | 2014-01-03 15:35:20 -0800 |
---|---|---|
committer | Hossein Falaki <falaki@gmail.com> | 2014-01-03 15:35:20 -0800 |
commit | 2c1cba851c2954bacf10006c0d5dad67aba77ab5 (patch) | |
tree | e52f369c5933b1dc58f3b94566ac84b6fc7d9eee /mllib | |
parent | 67f937ec222c5a7db5286c0af0ec6f9c482d2af6 (diff) | |
download | spark-2c1cba851c2954bacf10006c0d5dad67aba77ab5.tar.gz spark-2c1cba851c2954bacf10006c0d5dad67aba77ab5.tar.bz2 spark-2c1cba851c2954bacf10006c0d5dad67aba77ab5.zip |
Added unit tests for bulk prediction in MatrixFactorizationModel
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index fafc5ec5f2..e683a90f57 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -90,18 +90,34 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll { testALS(50, 100, 1, 15, 0.7, 0.3) } + test("rank-1 matrices bulk") { + testALS(50, 100, 1, 15, 0.7, 0.3, false, true) + } + test("rank-2 matrices") { testALS(100, 200, 2, 15, 0.7, 0.3) } + test("rank-2 matrices bulk") { + testALS(100, 200, 2, 15, 0.7, 0.3, false, true) + } + test("rank-1 matrices implicit") { testALS(80, 160, 1, 15, 0.7, 0.4, true) } + test("rank-1 matrices implicit bulk") { + testALS(80, 160, 1, 15, 0.7, 0.4, true, true) + } + test("rank-2 matrices implicit") { testALS(100, 200, 2, 15, 0.7, 0.4, true) } + test("rank-2 matrices implicit bulk") { + testALS(100, 200, 2, 15, 0.7, 0.4, true, true) + } + /** * Test if we can correctly factorize R = U * P where U and P are of known rank. * @@ -111,9 +127,12 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll { * @param iterations number of iterations to run * @param samplingRate what fraction of the user-product pairs are known * @param matchThreshold max difference allowed to consider a predicted rating correct + * @param implicitPrefs flag to test implicit feedback + * @param bulkPredict flag to test bulk prediciton */ def testALS(users: Int, products: Int, features: Int, iterations: Int, - samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false) + samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false, + bulkPredict: Boolean = false) { val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products, features, samplingRate, implicitPrefs) @@ -130,7 +149,17 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll { for ((p, vec) <- model.productFeatures.collect(); i <- 0 until features) { predictedP.put(p, i, vec(i)) } - val predictedRatings = predictedU.mmul(predictedP.transpose) + val predictedRatings = bulkPredict match { + case false => predictedU.mmul(predictedP.transpose) + case true => + val allRatings = new DoubleMatrix(users, products) + val usersProducts = for (u <- 0 until users; p <- 0 until products) yield (u, p) + val userProductsRDD = sc.parallelize(usersProducts) + model.predict(userProductsRDD).collect().foreach { elem => + allRatings.put(elem.user, elem.product, elem.rating) + } + allRatings + } if (!implicitPrefs) { for (u <- 0 until users; p <- 0 until products) { |