aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorHossein Falaki <falaki@gmail.com>2014-01-03 15:34:16 -0800
committerHossein Falaki <falaki@gmail.com>2014-01-03 15:34:16 -0800
commit67f937ec222c5a7db5286c0af0ec6f9c482d2af6 (patch)
treea33c47fce8cfd41539848752cdc2b7b2727d5c01 /mllib
parent0475ca8f81b6b8f21fdb841922cd9ab51cfc8cc3 (diff)
downloadspark-67f937ec222c5a7db5286c0af0ec6f9c482d2af6.tar.gz
spark-67f937ec222c5a7db5286c0af0ec6f9c482d2af6.tar.bz2
spark-67f937ec222c5a7db5286c0af0ec6f9c482d2af6.zip
Added a method to enable bulk prediction
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala24
1 files changed, 23 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index af43d89c70..bc13a66dbe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -20,7 +20,9 @@ package org.apache.spark.mllib.recommendation
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
+
import org.jblas._
+import java.nio.{ByteOrder, ByteBuffer}
/**
* Model representing the result of matrix factorization.
@@ -44,6 +46,26 @@ class MatrixFactorizationModel(
userVector.dot(productVector)
}
- // TODO: Figure out what good bulk prediction methods would look like.
+ /**
+ * Predict the rating of many users for many products.
+ * The output RDD has one element for each element of the input RDD (duplicates included),
+ * except that pairs whose user or product is missing from the training set are omitted.
+ *
+ * @param usersProducts RDD of (user, product) pairs.
+ * @return RDD of Ratings.
+ */
+ def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = {
+   // Attach each requested pair to its user's factors, then re-key by product
+   // so the product factors can be joined in next.
+   val keyedByProduct = userFeatures.join(usersProducts).map {
+     case (userId, (userFactors, productId)) => (productId, (userId, userFactors))
+   }
+   // The predicted rating is the dot product of the user and product factor vectors.
+   keyedByProduct.join(productFeatures).map {
+     case (productId, ((userId, userFactors), productFactors)) =>
+       val score = new DoubleMatrix(userFactors).dot(new DoubleMatrix(productFactors))
+       Rating(userId, productId, score)
+   }
+ }
+
+ // TODO: Figure out what other good bulk prediction methods would look like.
// Probably want a way to get the top users for a product or vice-versa.
}