diff options
Diffstat (limited to 'python/pyspark/mllib')
-rw-r--r-- | python/pyspark/mllib/recommendation.py | 32 |
1 files changed, 28 insertions, 4 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 95047b5b7b..b9442b0d16 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -76,16 +76,28 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> first_user = model.userFeatures().take(1)[0] >>> latents = first_user[1] - >>> len(latents) == 4 - True + >>> len(latents) + 4 >>> model.productFeatures().collect() [(1, array('d', [...])), (2, array('d', [...]))] >>> first_product = model.productFeatures().take(1)[0] >>> latents = first_product[1] - >>> len(latents) == 4 - True + >>> len(latents) + 4 + + >>> products_for_users = model.recommendProductsForUsers(1).collect() + >>> len(products_for_users) + 2 + >>> products_for_users[0] + (1, (Rating(user=1, product=2, rating=...),)) + + >>> users_for_products = model.recommendUsersForProducts(1).collect() + >>> len(users_for_products) + 2 + >>> users_for_products[0] + (1, (Rating(user=2, product=1, rating=...),)) >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) @@ -166,6 +178,18 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ return list(self.call("recommendProducts", user, num)) + def recommendProductsForUsers(self, num): + """ + Recommends top "num" products for all users. The number returned may be less than this. + """ + return self.call("wrappedRecommendProductsForUsers", num) + + def recommendUsersForProducts(self, num): + """ + Recommends top "num" users for all products. The number returned may be less than this. + """ + return self.call("wrappedRecommendUsersForProducts", num) + @property @since("1.4.0") def rank(self): |