aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorKai Jiang <jiangkai@gmail.com>2016-02-11 15:50:33 -0800
committerXiangrui Meng <meng@databricks.com>2016-02-11 15:50:33 -0800
commitc8f667d7c1a0b02685e17b6f498879b05ced9b9d (patch)
tree5875714d311c589973ea9adcd79ca3bd582d4238 /python
parent574571c87098795a2206a113ee9ed4bafba8f00f (diff)
downloadspark-c8f667d7c1a0b02685e17b6f498879b05ced9b9d.tar.gz
spark-c8f667d7c1a0b02685e17b6f498879b05ced9b9d.tar.bz2
spark-c8f667d7c1a0b02685e17b6f498879b05ced9b9d.zip
[SPARK-13037][ML][PYSPARK] PySpark ml.recommendation support export/import
PySpark ml.recommendation support export/import. Author: Kai Jiang <jiangkai@gmail.com> Closes #11044 from vectorijk/spark-13037.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/recommendation.py31
1 file changed, 27 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 08180a2f25..ef9448855e 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -16,7 +16,7 @@
#
from pyspark import since
-from pyspark.ml.util import keyword_only
+from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc
@@ -26,7 +26,8 @@ __all__ = ['ALS', 'ALSModel']
@inherit_doc
-class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed):
+class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed,
+ MLWritable, MLReadable):
"""
Alternating Least Squares (ALS) matrix factorization.
@@ -81,6 +82,27 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
Row(user=1, item=0, prediction=2.6258413791656494)
>>> predictions[2]
Row(user=2, item=0, prediction=-1.5018409490585327)
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> als_path = path + "/als"
+ >>> als.save(als_path)
+ >>> als2 = ALS.load(als_path)
+ >>> als.getMaxIter()
+ 5
+ >>> model_path = path + "/als_model"
+ >>> model.save(model_path)
+ >>> model2 = ALSModel.load(model_path)
+ >>> model.rank == model2.rank
+ True
+ >>> sorted(model.userFactors.collect()) == sorted(model2.userFactors.collect())
+ True
+ >>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect())
+ True
+ >>> from shutil import rmtree
+ >>> try:
+ ... rmtree(path)
+ ... except OSError:
+ ... pass
.. versionadded:: 1.4.0
"""
@@ -274,7 +296,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
return self.getOrDefault(self.nonnegative)
-class ALSModel(JavaModel):
+class ALSModel(JavaModel, MLWritable, MLReadable):
"""
Model fitted by ALS.
@@ -308,9 +330,10 @@ class ALSModel(JavaModel):
if __name__ == "__main__":
import doctest
+ import pyspark.ml.recommendation
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
- globs = globals().copy()
+ globs = pyspark.ml.recommendation.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.recommendation tests")