diff options
author | Hossein Falaki <falaki@gmail.com> | 2014-01-03 15:34:16 -0800 |
---|---|---|
committer | Hossein Falaki <falaki@gmail.com> | 2014-01-03 15:34:16 -0800 |
commit | 67f937ec222c5a7db5286c0af0ec6f9c482d2af6 (patch) | |
tree | a33c47fce8cfd41539848752cdc2b7b2727d5c01 /mllib | |
parent | 0475ca8f81b6b8f21fdb841922cd9ab51cfc8cc3 (diff) | |
download | spark-67f937ec222c5a7db5286c0af0ec6f9c482d2af6.tar.gz spark-67f937ec222c5a7db5286c0af0ec6f9c482d2af6.tar.bz2 spark-67f937ec222c5a7db5286c0af0ec6f9c482d2af6.zip |
Added a method to enable bulk prediction
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala | 24 |
1 files changed, 23 insertions, 1 deletions
```diff
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index af43d89c70..bc13a66dbe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -20,7 +20,9 @@ package org.apache.spark.mllib.recommendation
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
+
 import org.jblas._
+import java.nio.{ByteOrder, ByteBuffer}
 
 /**
  * Model representing the result of matrix factorization.
@@ -44,6 +46,26 @@ class MatrixFactorizationModel(
     userVector.dot(productVector)
   }
 
-  // TODO: Figure out what good bulk prediction methods would look like.
+  /**
+   * Predict the rating of many users for many products.
+   * The output RDD has an element per each element in the input RDD (including all duplicates)
+   * unless a user or product is missing in the training set.
+   *
+   * @param usersProducts RDD of (user, product) pairs.
+   * @return RDD of Ratings.
+   */
+  def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = {
+    val users = userFeatures.join(usersProducts).map{
+      case (user, (uFeatures, product)) => (product, (user, uFeatures))
+    }
+    users.join(productFeatures).map {
+      case (product, ((user, uFeatures), pFeatures)) =>
+        val userVector = new DoubleMatrix(uFeatures)
+        val productVector = new DoubleMatrix(pFeatures)
+        Rating(user, product, userVector.dot(productVector))
+    }
+  }
+
+  // TODO: Figure out what other good bulk prediction methods would look like.
   // Probably want a way to get the top users for a product or vice-versa.
 }
```