aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/mllib/recommendation.py26
1 files changed, 15 insertions, 11 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 41bbd9a779..2bcbf2aaf8 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -15,24 +15,28 @@
# limitations under the License.
#
+from collections import namedtuple
+
from pyspark import SparkContext
from pyspark.rdd import RDD
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd
-__all__ = ['MatrixFactorizationModel', 'ALS']
+__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
-class Rating(object):
- def __init__(self, user, product, rating):
- self.user = int(user)
- self.product = int(product)
- self.rating = float(rating)
+class Rating(namedtuple("Rating", ["user", "product", "rating"])):
+ """
+ Represents a (user, product, rating) tuple.
- def __reduce__(self):
- return Rating, (self.user, self.product, self.rating)
+ >>> r = Rating(1, 2, 5.0)
+ >>> (r.user, r.product, r.rating)
+ (1, 2, 5.0)
+ >>> (r[0], r[1], r[2])
+ (1, 2, 5.0)
+ """
- def __repr__(self):
- return "Rating(%d, %d, %s)" % (self.user, self.product, self.rating)
+ def __reduce__(self):
+ return Rating, (int(self.user), int(self.product), float(self.rating))
class MatrixFactorizationModel(JavaModelWrapper):
@@ -51,7 +55,7 @@ class MatrixFactorizationModel(JavaModelWrapper):
>>> testset = sc.parallelize([(1, 2), (1, 1)])
>>> model = ALS.train(ratings, 1, seed=10)
>>> model.predictAll(testset).collect()
- [Rating(1, 1, 1.0471...), Rating(1, 2, 1.9679...)]
+ [Rating(user=1, product=1, rating=1.0471...), Rating(user=1, product=2, rating=1.9679...)]
>>> model = ALS.train(ratings, 4, seed=10)
>>> model.userFeatures().collect()