diff options
author | Vladimir Vladimirov <vladimir.vladimirov@magnetic.com> | 2015-10-09 14:16:13 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-10-09 14:16:13 -0700 |
commit | c1b4ce43264fa8b9945df3c599a51d4d2a675705 (patch) | |
tree | 192333393be8d38f558327659a03904e5d3a5c68 /python | |
parent | 63c340a710b24869410d56602b712fbfe443e6f0 (diff) | |
download | spark-c1b4ce43264fa8b9945df3c599a51d4d2a675705.tar.gz spark-c1b4ce43264fa8b9945df3c599a51d4d2a675705.tar.bz2 spark-c1b4ce43264fa8b9945df3c599a51d4d2a675705.zip |
[SPARK-10535] Sync up API for matrix factorization model between Scala and PySpark
Support for recommendUsersForProducts and recommendProductsForUsers in matrix factorization model for PySpark
Author: Vladimir Vladimirov <vladimir.vladimirov@magnetic.com>
Closes #8700 from smartkiwi/SPARK-10535_.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/recommendation.py | 32 |
1 files changed, 28 insertions, 4 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 95047b5b7b..b9442b0d16 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -76,16 +76,28 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> first_user = model.userFeatures().take(1)[0] >>> latents = first_user[1] - >>> len(latents) == 4 - True + >>> len(latents) + 4 >>> model.productFeatures().collect() [(1, array('d', [...])), (2, array('d', [...]))] >>> first_product = model.productFeatures().take(1)[0] >>> latents = first_product[1] - >>> len(latents) == 4 - True + >>> len(latents) + 4 + + >>> products_for_users = model.recommendProductsForUsers(1).collect() + >>> len(products_for_users) + 2 + >>> products_for_users[0] + (1, (Rating(user=1, product=2, rating=...),)) + + >>> users_for_products = model.recommendUsersForProducts(1).collect() + >>> len(users_for_products) + 2 + >>> users_for_products[0] + (1, (Rating(user=2, product=1, rating=...),)) >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) @@ -166,6 +178,18 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ return list(self.call("recommendProducts", user, num)) + def recommendProductsForUsers(self, num): + """ + Recommends top "num" products for all users. The number returned may be less than this. + """ + return self.call("wrappedRecommendProductsForUsers", num) + + def recommendUsersForProducts(self, num): + """ + Recommends top "num" users for all products. The number returned may be less than this. + """ + return self.call("wrappedRecommendUsersForProducts", num) + @property @since("1.4.0") def rank(self): |