aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala | 8
-rw-r--r-- python/pyspark/mllib/recommendation.py | 32
2 files changed, 36 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala
index 534edac56b..eeb7cba882 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala
@@ -42,4 +42,12 @@ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorization
case (product, feature) => (product, Vectors.dense(feature))
}.asInstanceOf[RDD[(Any, Any)]])
}
+
+ /**
+ * Python-facing variant of recommendProductsForUsers: the PySpark
+ * MatrixFactorizationModel.recommendProductsForUsers method invokes it by name.
+ *
+ * The result of recommendProductsForUsers(num) is cast to RDD[(Any, Any)] and handed
+ * to SerDe.fromTuple2RDD, whose RDD[Array[Any]] result is returned.
+ * NOTE(review): presumably fromTuple2RDD turns each pair into a two-element array so
+ * it can be shipped to Python -- confirm against SerDe.
+ *
+ * @param num number of products requested per user
+ */
+ def wrappedRecommendProductsForUsers(num: Int): RDD[Array[Any]] = {
+ SerDe.fromTuple2RDD(recommendProductsForUsers(num).asInstanceOf[RDD[(Any, Any)]])
+ }
+
+ /**
+ * Python-facing variant of recommendUsersForProducts: the PySpark
+ * MatrixFactorizationModel.recommendUsersForProducts method invokes it by name.
+ *
+ * The result of recommendUsersForProducts(num) is cast to RDD[(Any, Any)] and handed
+ * to SerDe.fromTuple2RDD, whose RDD[Array[Any]] result is returned.
+ * NOTE(review): presumably fromTuple2RDD turns each pair into a two-element array so
+ * it can be shipped to Python -- confirm against SerDe.
+ *
+ * @param num number of users requested per product
+ */
+ def wrappedRecommendUsersForProducts(num: Int): RDD[Array[Any]] = {
+ SerDe.fromTuple2RDD(recommendUsersForProducts(num).asInstanceOf[RDD[(Any, Any)]])
+ }
}
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 95047b5b7b..b9442b0d16 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -76,16 +76,28 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
>>> first_user = model.userFeatures().take(1)[0]
>>> latents = first_user[1]
- >>> len(latents) == 4
- True
+ >>> len(latents)
+ 4
>>> model.productFeatures().collect()
[(1, array('d', [...])), (2, array('d', [...]))]
>>> first_product = model.productFeatures().take(1)[0]
>>> latents = first_product[1]
- >>> len(latents) == 4
- True
+ >>> len(latents)
+ 4
+
+ >>> products_for_users = model.recommendProductsForUsers(1).collect()
+ >>> len(products_for_users)
+ 2
+ >>> products_for_users[0]
+ (1, (Rating(user=1, product=2, rating=...),))
+
+ >>> users_for_products = model.recommendUsersForProducts(1).collect()
+ >>> len(users_for_products)
+ 2
+ >>> users_for_products[0]
+ (1, (Rating(user=2, product=1, rating=...),))
>>> model = ALS.train(ratings, 1, nonnegative=True, seed=10)
>>> model.predict(2, 2)
@@ -166,6 +178,18 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
"""
return list(self.call("recommendProducts", user, num))
+ def recommendProductsForUsers(self, num):
+ """
+ Recommends the top "num" products for every user. The number of products
+ returned for a user may be less than "num".
+
+ The result of the JVM-side "wrappedRecommendProductsForUsers" call is returned
+ as-is (it is not turned into a list the way recommendProducts does). The class
+ doctest calls collect() on it and gets (user, (Rating, ...)) pairs.
+ """
+ return self.call("wrappedRecommendProductsForUsers", num)
+
+ def recommendUsersForProducts(self, num):
+ """
+ Recommends the top "num" users for every product. The number of users
+ returned for a product may be less than "num".
+
+ The result of the JVM-side "wrappedRecommendUsersForProducts" call is returned
+ as-is (it is not turned into a list the way recommendProducts does). The class
+ doctest calls collect() on it and gets (product, (Rating, ...)) pairs.
+ """
+ return self.call("wrappedRecommendUsersForProducts", num)
+
@property
@since("1.4.0")
def rank(self):