diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-03-01 16:26:57 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-03-01 16:26:57 -0800 |
commit | aedbbaa3dda9cbc154cd52c07f6d296b972b0eb2 (patch) | |
tree | 4ba785e145c21b93e1e4c49ae33899642b1f3cea /python | |
parent | fd8d283eeb98e310b1e85ef8c3a8af9e547ab5e0 (diff) | |
download | spark-aedbbaa3dda9cbc154cd52c07f6d296b972b0eb2.tar.gz spark-aedbbaa3dda9cbc154cd52c07f6d296b972b0eb2.tar.bz2 spark-aedbbaa3dda9cbc154cd52c07f6d296b972b0eb2.zip |
[SPARK-6053][MLLIB] support save/load in PySpark's ALS
A simple wrapper to save/load `MatrixFactorizationModel` in Python. jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #4811 from mengxr/SPARK-5991 and squashes the following commits:
f135dac [Xiangrui Meng] update save doc
57e5200 [Xiangrui Meng] address comments
06140a4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5991
282ec8d [Xiangrui Meng] support save/load in PySpark's ALS
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/recommendation.py | 20 | ||||
-rw-r--r-- | python/pyspark/mllib/util.py | 58 |
2 files changed, 76 insertions, 2 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 0d99e6dedf..03d7d01147 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -19,7 +19,8 @@ from collections import namedtuple from pyspark import SparkContext from pyspark.rdd import RDD -from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc +from pyspark.mllib.util import Saveable, JavaLoader __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating'] @@ -39,7 +40,8 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])): return Rating, (int(self.user), int(self.product), float(self.rating)) -class MatrixFactorizationModel(JavaModelWrapper): +@inherit_doc +class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader): """A matrix factorisation model trained by regularized alternating least-squares. @@ -81,6 +83,17 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2,2) 0.43... + + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = MatrixFactorizationModel.load(sc, path) + >>> sameModel.predict(2,2) + 0.43... + >>> try: + ... os.removedirs(path) + ... except: + ... pass """ def predict(self, user, product): return self._java_model.predict(int(user), int(product)) @@ -98,6 +111,9 @@ class MatrixFactorizationModel(JavaModelWrapper): def productFeatures(self): return self.call("getProductFeatures") + def save(self, sc, path): + self.call("save", sc._jsc.sc(), path) + class ALS(object): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 4ed978b454..17d43eadba 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -168,6 +168,64 @@ class MLUtils(object): return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) +class Saveable(object): + """ + Mixin for models and transformers which may be saved as files. + """ + + def save(self, sc, path): + """ + Save this model to the given path. + + This saves: + * human-readable (JSON) model metadata to path/metadata/ + * Parquet formatted data to path/data/ + + The model may be loaded using py:meth:`Loader.load`. + + :param sc: Spark context used to save model data. + :param path: Path specifying the directory in which to save + this model. If the directory already exists, + this method throws an exception. + """ + raise NotImplementedError + + +class Loader(object): + """ + Mixin for classes which can load saved models from files. + """ + + @classmethod + def load(cls, sc, path): + """ + Load a model from the given path. The model should have been + saved using py:meth:`Saveable.save`. + + :param sc: Spark context used for loading model files. + :param path: Path specifying the directory to which the model + was saved. + :return: model instance + """ + raise NotImplemented + + +class JavaLoader(Loader): + """ + Mixin for classes which can load saved models using its Scala + implementation. + """ + + @classmethod + def load(cls, sc, path): + java_package = cls.__module__.replace("pyspark", "org.apache.spark") + java_class = ".".join([java_package, cls.__name__]) + java_obj = sc._jvm + for name in java_class.split("."): + java_obj = getattr(java_obj, name) + return cls(java_obj.load(sc._jsc.sc(), path)) + + def _test(): import doctest from pyspark.context import SparkContext |