aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/recommendation.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/recommendation.py')
-rw-r--r--python/pyspark/mllib/recommendation.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index e26b152e0c..41bbd9a779 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -32,7 +32,7 @@ class Rating(object):
return Rating, (self.user, self.product, self.rating)
def __repr__(self):
- return "Rating(%d, %d, %d)" % (self.user, self.product, self.rating)
+ return "Rating(%d, %d, %s)" % (self.user, self.product, self.rating)
class MatrixFactorizationModel(JavaModelWrapper):
@@ -51,7 +51,7 @@ class MatrixFactorizationModel(JavaModelWrapper):
>>> testset = sc.parallelize([(1, 2), (1, 1)])
>>> model = ALS.train(ratings, 1, seed=10)
>>> model.predictAll(testset).collect()
- [Rating(1, 1, 1), Rating(1, 2, 1)]
+ [Rating(1, 1, 1.0471...), Rating(1, 2, 1.9679...)]
>>> model = ALS.train(ratings, 4, seed=10)
>>> model.userFeatures().collect()
@@ -79,7 +79,7 @@ class MatrixFactorizationModel(JavaModelWrapper):
0.4473...
"""
def predict(self, user, product):
- return self._java_model.predict(user, product)
+ return self._java_model.predict(int(user), int(product))
def predictAll(self, user_product):
assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)"