diff options
author | Kai Jiang <jiangkai@gmail.com> | 2016-02-11 15:50:33 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-02-11 15:50:33 -0800 |
commit | c8f667d7c1a0b02685e17b6f498879b05ced9b9d (patch) | |
tree | 5875714d311c589973ea9adcd79ca3bd582d4238 /python/pyspark | |
parent | 574571c87098795a2206a113ee9ed4bafba8f00f (diff) | |
download | spark-c8f667d7c1a0b02685e17b6f498879b05ced9b9d.tar.gz spark-c8f667d7c1a0b02685e17b6f498879b05ced9b9d.tar.bz2 spark-c8f667d7c1a0b02685e17b6f498879b05ced9b9d.zip |
[SPARK-13037][ML][PYSPARK] PySpark ml.recommendation support export/import
PySpark ml.recommendation support export/import.
Author: Kai Jiang <jiangkai@gmail.com>
Closes #11044 from vectorijk/spark-13037.
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/ml/recommendation.py | 31 |
1 files changed, 27 insertions, 4 deletions
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 08180a2f25..ef9448855e 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -16,7 +16,7 @@ # from pyspark import since -from pyspark.ml.util import keyword_only +from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc @@ -26,7 +26,8 @@ __all__ = ['ALS', 'ALSModel'] @inherit_doc -class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed): +class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed, + MLWritable, MLReadable): """ Alternating Least Squares (ALS) matrix factorization. @@ -81,6 +82,27 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=2.6258413791656494) >>> predictions[2] Row(user=2, item=0, prediction=-1.5018409490585327) + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> als_path = path + "/als" + >>> als.save(als_path) + >>> als2 = ALS.load(als_path) + >>> als.getMaxIter() + 5 + >>> model_path = path + "/als_model" + >>> model.save(model_path) + >>> model2 = ALSModel.load(model_path) + >>> model.rank == model2.rank + True + >>> sorted(model.userFactors.collect()) == sorted(model2.userFactors.collect()) + True + >>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect()) + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass .. versionadded:: 1.4.0 """ @@ -274,7 +296,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha return self.getOrDefault(self.nonnegative) -class ALSModel(JavaModel): +class ALSModel(JavaModel, MLWritable, MLReadable): """ Model fitted by ALS. @@ -308,9 +330,10 @@ class ALSModel(JavaModel): if __name__ == "__main__": import doctest + import pyspark.ml.recommendation from pyspark.context import SparkContext from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.ml.recommendation.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.recommendation tests") |