diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-03-01 16:26:57 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-03-01 16:26:57 -0800 |
commit | aedbbaa3dda9cbc154cd52c07f6d296b972b0eb2 (patch) | |
tree | 4ba785e145c21b93e1e4c49ae33899642b1f3cea /python/pyspark/mllib/recommendation.py | |
parent | fd8d283eeb98e310b1e85ef8c3a8af9e547ab5e0 (diff) | |
download | spark-aedbbaa3dda9cbc154cd52c07f6d296b972b0eb2.tar.gz spark-aedbbaa3dda9cbc154cd52c07f6d296b972b0eb2.tar.bz2 spark-aedbbaa3dda9cbc154cd52c07f6d296b972b0eb2.zip |
[SPARK-6053][MLLIB] support save/load in PySpark's ALS
A simple wrapper to save/load `MatrixFactorizationModel` in Python. jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #4811 from mengxr/SPARK-5991 and squashes the following commits:
f135dac [Xiangrui Meng] update save doc
57e5200 [Xiangrui Meng] address comments
06140a4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5991
282ec8d [Xiangrui Meng] support save/load in PySpark's ALS
Diffstat (limited to 'python/pyspark/mllib/recommendation.py')
-rw-r--r-- | python/pyspark/mllib/recommendation.py | 20 |
1 file changed, 18 insertions, 2 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 0d99e6dedf..03d7d01147 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -19,7 +19,8 @@ from collections import namedtuple from pyspark import SparkContext from pyspark.rdd import RDD -from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc +from pyspark.mllib.util import Saveable, JavaLoader __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating'] @@ -39,7 +40,8 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])): return Rating, (int(self.user), int(self.product), float(self.rating)) -class MatrixFactorizationModel(JavaModelWrapper): +@inherit_doc +class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader): """A matrix factorisation model trained by regularized alternating least-squares. @@ -81,6 +83,17 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2,2) 0.43... + + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = MatrixFactorizationModel.load(sc, path) + >>> sameModel.predict(2,2) + 0.43... + >>> try: + ... os.removedirs(path) + ... except: + ... pass """ def predict(self, user, product): return self._java_model.predict(int(user), int(product)) @@ -98,6 +111,9 @@ class MatrixFactorizationModel(JavaModelWrapper): def productFeatures(self): return self.call("getProductFeatures") + def save(self, sc, path): + self.call("save", sc._jsc.sc(), path) + class ALS(object): |