aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/recommendation.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/recommendation.py')
-rw-r--r--python/pyspark/mllib/recommendation.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 14d06cba21..c81b482a87 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -20,7 +20,10 @@ from pyspark.mllib._common import \
_get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
_serialize_double_matrix, _deserialize_double_matrix, \
_serialize_double_vector, _deserialize_double_vector, \
- _get_initial_weights, _serialize_rating, _regression_train_wrapper
+ _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
+ _serialize_tuple, _deserialize_rating
+from pyspark.serializers import BatchedSerializer
+from pyspark.rdd import RDD
class MatrixFactorizationModel(object):
"""A matrix factorisation model trained by regularized alternating
@@ -45,6 +48,11 @@ class MatrixFactorizationModel(object):
def predict(self, user, product):
return self._java_model.predict(user, product)
+ def predictAll(self, usersProducts):
+ usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
+ return RDD(self._java_model.predictJavaRDD(usersProductsJRDD._jrdd),
+ self._context, BatchedSerializer(_deserialize_rating, self._context._batchSize))
+
class ALS(object):
@classmethod
def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):