diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/recommendation.py | 26 |
1 files changed, 15 insertions, 11 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 41bbd9a779..2bcbf2aaf8 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -15,24 +15,28 @@ # limitations under the License. # +from collections import namedtuple + from pyspark import SparkContext from pyspark.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd -__all__ = ['MatrixFactorizationModel', 'ALS'] +__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating'] -class Rating(object): - def __init__(self, user, product, rating): - self.user = int(user) - self.product = int(product) - self.rating = float(rating) +class Rating(namedtuple("Rating", ["user", "product", "rating"])): + """ + Represents a (user, product, rating) tuple. - def __reduce__(self): - return Rating, (self.user, self.product, self.rating) + >>> r = Rating(1, 2, 5.0) + >>> (r.user, r.product, r.rating) + (1, 2, 5.0) + >>> (r[0], r[1], r[2]) + (1, 2, 5.0) + """ - def __repr__(self): - return "Rating(%d, %d, %s)" % (self.user, self.product, self.rating) + def __reduce__(self): + return Rating, (int(self.user), int(self.product), float(self.rating)) class MatrixFactorizationModel(JavaModelWrapper): @@ -51,7 +55,7 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> testset = sc.parallelize([(1, 2), (1, 1)]) >>> model = ALS.train(ratings, 1, seed=10) >>> model.predictAll(testset).collect() - [Rating(1, 1, 1.0471...), Rating(1, 2, 1.9679...)] + [Rating(user=1, product=1, rating=1.0471...), Rating(user=1, product=2, rating=1.9679...)] >>> model = ALS.train(ratings, 4, seed=10) >>> model.userFeatures().collect() |