aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/recommendation.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/recommendation.py')
-rw-r--r--python/pyspark/mllib/recommendation.py20
1 files changed, 18 insertions, 2 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 0d99e6dedf..03d7d01147 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -19,7 +19,8 @@ from collections import namedtuple
from pyspark import SparkContext
from pyspark.rdd import RDD
-from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
+from pyspark.mllib.util import Saveable, JavaLoader
__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
@@ -39,7 +40,8 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])):
return Rating, (int(self.user), int(self.product), float(self.rating))
-class MatrixFactorizationModel(JavaModelWrapper):
+@inherit_doc
+class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
"""A matrix factorisation model trained by regularized alternating
least-squares.
@@ -81,6 +83,17 @@ class MatrixFactorizationModel(JavaModelWrapper):
>>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
>>> model.predict(2,2)
0.43...
+
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> model.save(sc, path)
+ >>> sameModel = MatrixFactorizationModel.load(sc, path)
+ >>> sameModel.predict(2,2)
+ 0.43...
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
"""
def predict(self, user, product):
return self._java_model.predict(int(user), int(product))
@@ -98,6 +111,9 @@ class MatrixFactorizationModel(JavaModelWrapper):
def productFeatures(self):
return self.call("getProductFeatures")
+ def save(self, sc, path):
+ self.call("save", sc._jsc.sc(), path)
+
class ALS(object):