aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-11-18 10:35:29 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-18 10:35:29 -0800
commitb54c6ab3c54e65238d6766832ea1f3fcd694f2fd (patch)
tree1a06f90e04a2d11f7c39fde74bdeaf55a9334101 /python/pyspark/mllib
parent8fbf72b7903b5bbec8d949151aa4693b4af26ff5 (diff)
downloadspark-b54c6ab3c54e65238d6766832ea1f3fcd694f2fd.tar.gz
spark-b54c6ab3c54e65238d6766832ea1f3fcd694f2fd.tar.bz2
spark-b54c6ab3c54e65238d6766832ea1f3fcd694f2fd.zip
[SPARK-4396] allow lookup by index in Python's Rating
In PySpark, ALS can take an RDD of (user, product, rating) tuples as input. However, model.predict outputs an RDD of Rating. So on the input side, users can use r[0], r[1], r[2], while on the output side, users have to use r.user, r.product, r.rating. We should allow lookup by index in Rating by making Rating a namedtuple. davies <!-- Reviewable:start --> [<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/3261) <!-- Reviewable:end --> Author: Xiangrui Meng <meng@databricks.com> Closes #3261 from mengxr/SPARK-4396 and squashes the following commits: 543aef0 [Xiangrui Meng] use named tuple to implement ALS 0b61bae [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-4396 d3bd7d4 [Xiangrui Meng] allow lookup by index in Python's Rating
Diffstat (limited to 'python/pyspark/mllib')
-rw-r--r--python/pyspark/mllib/recommendation.py26
1 files changed, 15 insertions, 11 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 41bbd9a779..2bcbf2aaf8 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -15,24 +15,28 @@
# limitations under the License.
#
+from collections import namedtuple
+
from pyspark import SparkContext
from pyspark.rdd import RDD
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd
-__all__ = ['MatrixFactorizationModel', 'ALS']
+__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
-class Rating(object):
- def __init__(self, user, product, rating):
- self.user = int(user)
- self.product = int(product)
- self.rating = float(rating)
+class Rating(namedtuple("Rating", ["user", "product", "rating"])):
+ """
+ Represents a (user, product, rating) tuple.
- def __reduce__(self):
- return Rating, (self.user, self.product, self.rating)
+ >>> r = Rating(1, 2, 5.0)
+ >>> (r.user, r.product, r.rating)
+ (1, 2, 5.0)
+ >>> (r[0], r[1], r[2])
+ (1, 2, 5.0)
+ """
- def __repr__(self):
- return "Rating(%d, %d, %s)" % (self.user, self.product, self.rating)
+ def __reduce__(self):
+ return Rating, (int(self.user), int(self.product), float(self.rating))
class MatrixFactorizationModel(JavaModelWrapper):
@@ -51,7 +55,7 @@ class MatrixFactorizationModel(JavaModelWrapper):
>>> testset = sc.parallelize([(1, 2), (1, 1)])
>>> model = ALS.train(ratings, 1, seed=10)
>>> model.predictAll(testset).collect()
- [Rating(1, 1, 1.0471...), Rating(1, 2, 1.9679...)]
+ [Rating(user=1, product=1, rating=1.0471...), Rating(user=1, product=2, rating=1.9679...)]
>>> model = ALS.train(ratings, 4, seed=10)
>>> model.userFeatures().collect()