diff options
Diffstat (limited to 'python/pyspark/mllib/recommendation.py')
-rw-r--r-- | python/pyspark/mllib/recommendation.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index c5c4c13dae..80e0a356bb 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -15,6 +15,7 @@ # limitations under the License. # +import array from collections import namedtuple from pyspark import SparkContext @@ -104,14 +105,14 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)" first = user_product.first() assert len(first) == 2, "user_product should be RDD of (user, product)" - user_product = user_product.map(lambda (u, p): (int(u), int(p))) + user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1]))) return self.call("predict", user_product) def userFeatures(self): - return self.call("getUserFeatures") + return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v)) def productFeatures(self): - return self.call("getProductFeatures") + return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v)) @classmethod def load(cls, sc, path): |