aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala9
1 files changed, 8 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 2c3e828300..443fc5de5b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -67,7 +67,14 @@ class MatrixFactorizationModel(
}
}
- def predictJavaRDD(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
+ /**
+ * Predict the rating of many users for many products.
+ * This is a Java stub for python predictAll()
+ *
+ * @param usersProductsJRDD A JavaRDD with serialized tuples (user, product)
+ * @return JavaRDD of serialized Rating objects.
+ */
+ def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
val pythonAPI = new PythonMLLibAPI()
val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes))
predict(usersProducts).map(rate => pythonAPI.serializeRating(rate))