diff options
author | Xiangrui Meng <meng@databricks.com> | 2014-11-18 10:35:29 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-18 10:35:29 -0800 |
commit | b54c6ab3c54e65238d6766832ea1f3fcd694f2fd (patch) | |
tree | 1a06f90e04a2d11f7c39fde74bdeaf55a9334101 /python/pyspark | |
parent | 8fbf72b7903b5bbec8d949151aa4693b4af26ff5 (diff) | |
download | spark-b54c6ab3c54e65238d6766832ea1f3fcd694f2fd.tar.gz spark-b54c6ab3c54e65238d6766832ea1f3fcd694f2fd.tar.bz2 spark-b54c6ab3c54e65238d6766832ea1f3fcd694f2fd.zip |
[SPARK-4396] allow lookup by index in Python's Rating
In PySpark, ALS can take an RDD of (user, product, rating) tuples as input. However, model.predict outputs an RDD of Rating. So on the input side, users can use r[0], r[1], r[2], while on the output side, users have to use r.user, r.product, r.rating. We should allow lookup by index in Rating by making Rating a namedtuple.
davies
<!-- Reviewable:start -->
[<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/3261)
<!-- Reviewable:end -->
Author: Xiangrui Meng <meng@databricks.com>
Closes #3261 from mengxr/SPARK-4396 and squashes the following commits:
543aef0 [Xiangrui Meng] use named tuple to implement ALS
0b61bae [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-4396
d3bd7d4 [Xiangrui Meng] allow lookup by index in Python's Rating
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/mllib/recommendation.py | 26 |
1 files changed, 15 insertions, 11 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 41bbd9a779..2bcbf2aaf8 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -15,24 +15,28 @@ # limitations under the License. # +from collections import namedtuple + from pyspark import SparkContext from pyspark.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd -__all__ = ['MatrixFactorizationModel', 'ALS'] +__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating'] -class Rating(object): - def __init__(self, user, product, rating): - self.user = int(user) - self.product = int(product) - self.rating = float(rating) +class Rating(namedtuple("Rating", ["user", "product", "rating"])): + """ + Represents a (user, product, rating) tuple. - def __reduce__(self): - return Rating, (self.user, self.product, self.rating) + >>> r = Rating(1, 2, 5.0) + >>> (r.user, r.product, r.rating) + (1, 2, 5.0) + >>> (r[0], r[1], r[2]) + (1, 2, 5.0) + """ - def __repr__(self): - return "Rating(%d, %d, %s)" % (self.user, self.product, self.rating) + def __reduce__(self): + return Rating, (int(self.user), int(self.product), float(self.rating)) class MatrixFactorizationModel(JavaModelWrapper): @@ -51,7 +55,7 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> testset = sc.parallelize([(1, 2), (1, 1)]) >>> model = ALS.train(ratings, 1, seed=10) >>> model.predictAll(testset).collect() - [Rating(1, 1, 1.0471...), Rating(1, 2, 1.9679...)] + [Rating(user=1, product=1, rating=1.0471...), Rating(user=1, product=2, rating=1.9679...)] >>> model = ALS.train(ratings, 4, seed=10) >>> model.userFeatures().collect() |