about | summary | refs | log | tree | commit | diff
path: root/python/pyspark
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark')
-rw-r--r-- python/pyspark/mllib/recommendation.py 31
1 files changed, 31 insertions, 0 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 17f96b8700..22872dbbe3 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -53,6 +53,23 @@ class MatrixFactorizationModel(object):
>>> model = ALS.train(ratings, 1)
>>> model.predictAll(testset).count() == 2
True
+
+ >>> model = ALS.train(ratings, 4)
+ >>> model.userFeatures().count() == 2
+ True
+
+ >>> first_user = model.userFeatures().take(1)[0]
+ >>> latents = first_user[1]
+ >>> len(latents) == 4
+ True
+
+ >>> model.productFeatures().count() == 2
+ True
+
+ >>> first_product = model.productFeatures().take(1)[0]
+ >>> latents = first_product[1]
+ >>> len(latents) == 4
+ True
"""
def __init__(self, sc, java_model):
@@ -83,6 +100,20 @@ class MatrixFactorizationModel(object):
return RDD(sc._jvm.SerDe.javaToPython(jresult), sc,
AutoBatchedSerializer(PickleSerializer()))
+ def userFeatures(self):
+ """
+ Return the user features of the wrapped Java model as a Python RDD.
+
+ In the class doctest each element of the result is indexable, and
+ the item at index 1 is a sequence whose length equals the second
+ argument passed to ALS.train (4 in that example).
+ """
+ sc = self._context
+ # Ask the JVM-side model for its user features.
+ juf = self._java_model.userFeatures()
+ # Presumably turns the Scala Tuple2 elements into a form that can be
+ # pickled for Python -- confirm against SerDe.fromTuple2RDD.
+ juf = sc._jvm.SerDe.fromTuple2RDD(juf).toJavaRDD()
+ # Wrap the converted Java RDD; its elements are read back with the
+ # auto-batched pickle serializer.
+ return RDD(sc._jvm.PythonRDD.javaToPython(juf), sc,
+ AutoBatchedSerializer(PickleSerializer()))
+
+ def productFeatures(self):
+ """
+ Return the product features of the wrapped Java model as a Python RDD.
+
+ In the class doctest each element of the result is indexable, and
+ the item at index 1 is a sequence whose length equals the second
+ argument passed to ALS.train (4 in that example).
+ """
+ sc = self._context
+ # Ask the JVM-side model for its product features.
+ jpf = self._java_model.productFeatures()
+ # Presumably turns the Scala Tuple2 elements into a form that can be
+ # pickled for Python -- confirm against SerDe.fromTuple2RDD.
+ jpf = sc._jvm.SerDe.fromTuple2RDD(jpf).toJavaRDD()
+ # Wrap the converted Java RDD; its elements are read back with the
+ # auto-batched pickle serializer.
+ return RDD(sc._jvm.PythonRDD.javaToPython(jpf), sc,
+ AutoBatchedSerializer(PickleSerializer()))
+
class ALS(object):