author    Xiangrui Meng <meng@databricks.com>  2015-03-01 16:26:57 -0800
committer Xiangrui Meng <meng@databricks.com>  2015-03-01 16:26:57 -0800
commit    aedbbaa3dda9cbc154cd52c07f6d296b972b0eb2
tree      4ba785e145c21b93e1e4c49ae33899642b1f3cea /python
parent    fd8d283eeb98e310b1e85ef8c3a8af9e547ab5e0
[SPARK-6053][MLLIB] support save/load in PySpark's ALS
A simple wrapper to save/load `MatrixFactorizationModel` in Python. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #4811 from mengxr/SPARK-5991 and squashes the following commits:

f135dac [Xiangrui Meng] update save doc
57e5200 [Xiangrui Meng] address comments
06140a4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5991
282ec8d [Xiangrui Meng] support save/load in PySpark's ALS
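For context, a minimal end-to-end sketch of the API this patch adds, mirroring the doctest in recommendation.py below (assumes an active SparkContext `sc`; the training data here is illustrative):

    import tempfile
    from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating

    # Train a tiny explicit-feedback model (rank=1, fixed seed, as in the doctests).
    ratings = sc.parallelize([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)])
    model = ALS.train(ratings, 1, seed=10)

    # save() writes the model under `path`; load() reconstructs it through the Scala side.
    path = tempfile.mkdtemp()
    model.save(sc, path)
    sameModel = MatrixFactorizationModel.load(sc, path)
    print(model.predict(2, 1))      # original model
    print(sameModel.predict(2, 1))  # reloaded model; should match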
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/mllib/recommendation.py | 20
-rw-r--r--  python/pyspark/mllib/util.py           | 58
2 files changed, 76 insertions, 2 deletions
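The key mechanism lives in `JavaLoader.load` (second diff below), which derives the fully qualified Scala class from the Python class's module and name, then calls that class's `load` via Py4J. A standalone sketch of that name mapping, pulled out for illustration (`java_class_for` is a hypothetical helper, not part of the patch):

    def java_class_for(cls):
        # e.g. pyspark.mllib.recommendation -> org.apache.spark.mllib.recommendation
        java_package = cls.__module__.replace("pyspark", "org.apache.spark")
        return ".".join([java_package, cls.__name__])

    from pyspark.mllib.recommendation import MatrixFactorizationModel
    print(java_class_for(MatrixFactorizationModel))
    # -> org.apache.spark.mllib.recommendation.MatrixFactorizationModel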
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 0d99e6dedf..03d7d01147 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -19,7 +19,8 @@ from collections import namedtuple
from pyspark import SparkContext
from pyspark.rdd import RDD
-from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
+from pyspark.mllib.util import Saveable, JavaLoader
__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
@@ -39,7 +40,8 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])):
return Rating, (int(self.user), int(self.product), float(self.rating))
-class MatrixFactorizationModel(JavaModelWrapper):
+@inherit_doc
+class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
"""A matrix factorisation model trained by regularized alternating
least-squares.
@@ -81,6 +83,17 @@ class MatrixFactorizationModel(JavaModelWrapper):
>>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
>>> model.predict(2,2)
0.43...
+
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> model.save(sc, path)
+ >>> sameModel = MatrixFactorizationModel.load(sc, path)
+ >>> sameModel.predict(2,2)
+ 0.43...
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
"""
def predict(self, user, product):
return self._java_model.predict(int(user), int(product))
@@ -98,6 +111,9 @@ class MatrixFactorizationModel(JavaModelWrapper):
def productFeatures(self):
return self.call("getProductFeatures")
+ def save(self, sc, path):
+ self.call("save", sc._jsc.sc(), path)
+
class ALS(object):
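The `Saveable.save` contract in the util.py diff that follows documents the on-disk layout: human-readable JSON metadata under `path/metadata/` and Parquet-formatted data under `path/data/`. A rough way to check this locally after the save in the sketch above (assumes `path` points at a model saved to the local filesystem):

    import os
    # Per the docstring below, the saved model directory should contain
    # a 'metadata' subdirectory (JSON) and a 'data' subdirectory (Parquet).
    for entry in sorted(os.listdir(path)):
        print(entry)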
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 4ed978b454..17d43eadba 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -168,6 +168,64 @@ class MLUtils(object):
return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
+class Saveable(object):
+ """
+ Mixin for models and transformers which may be saved as files.
+ """
+
+ def save(self, sc, path):
+ """
+ Save this model to the given path.
+
+ This saves:
+ * human-readable (JSON) model metadata to path/metadata/
+ * Parquet formatted data to path/data/
+
+ The model may be loaded using :py:meth:`Loader.load`.
+
+ :param sc: Spark context used to save model data.
+ :param path: Path specifying the directory in which to save
+ this model. If the directory already exists,
+ this method throws an exception.
+ """
+ raise NotImplementedError
+
+
+class Loader(object):
+ """
+ Mixin for classes which can load saved models from files.
+ """
+
+ @classmethod
+ def load(cls, sc, path):
+ """
+ Load a model from the given path. The model should have been
+ saved using :py:meth:`Saveable.save`.
+
+ :param sc: Spark context used for loading model files.
+ :param path: Path specifying the directory to which the model
+ was saved.
+ :return: model instance
+ """
+ raise NotImplementedError
+
+
+class JavaLoader(Loader):
+ """
+ Mixin for classes which can load saved models using its Scala
+ implementation.
+ """
+
+ @classmethod
+ def load(cls, sc, path):
+ java_package = cls.__module__.replace("pyspark", "org.apache.spark")
+ java_class = ".".join([java_package, cls.__name__])
+ java_obj = sc._jvm
+ for name in java_class.split("."):
+ java_obj = getattr(java_obj, name)
+ return cls(java_obj.load(sc._jsc.sc(), path))
+
+
def _test():
import doctest
from pyspark.context import SparkContext