aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/regression.py
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-01-29 09:22:24 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-29 09:22:24 -0800
commite51b6eaa9e9c007e194d858195291b2b9fb27322 (patch)
treeb6af90c439154fe7514fd32e47a56a693ffd745a /python/pyspark/ml/regression.py
parent55561e7693dd2a5bf3c7f8026c725421801fd0ec (diff)
downloadspark-e51b6eaa9e9c007e194d858195291b2b9fb27322.tar.gz
spark-e51b6eaa9e9c007e194d858195291b2b9fb27322.tar.bz2
spark-e51b6eaa9e9c007e194d858195291b2b9fb27322.zip
[SPARK-13032][ML][PYSPARK] PySpark support model export/import and take LinearRegression as example
* Implement ```MLWriter/MLWritable/MLReader/MLReadable``` for PySpark. * Making ```LinearRegression``` to support ```save/load``` as example. After this merged, the work for other transformers/estimators will be easy, then we can list and distribute the tasks to the community. cc mengxr jkbradley Author: Yanbo Liang <ybliang8@gmail.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #10469 from yanboliang/spark-11939.
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--python/pyspark/ml/regression.py30
1 file changed, 25 insertions, 5 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 74a2248ed0..20dc6c2db9 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -18,9 +18,9 @@
import warnings
from pyspark import since
-from pyspark.ml.util import keyword_only
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
+from pyspark.ml.util import *
+from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.mllib.common import inherit_doc
@@ -35,7 +35,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
- HasStandardization, HasSolver, HasWeightCol):
+ HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable):
"""
Linear regression.
@@ -68,6 +68,25 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
Traceback (most recent call last):
...
TypeError: Method setParams forces keyword arguments.
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lr_path = path + "/lr"
+ >>> lr.save(lr_path)
+ >>> lr2 = LinearRegression.load(lr_path)
+ >>> lr2.getMaxIter()
+ 5
+ >>> model_path = path + "/lr_model"
+ >>> model.save(model_path)
+ >>> model2 = LinearRegressionModel.load(model_path)
+ >>> model.coefficients[0] == model2.coefficients[0]
+ True
+ >>> model.intercept == model2.intercept
+ True
+ >>> from shutil import rmtree
+ >>> try:
+ ... rmtree(path)
+ ... except OSError:
+ ... pass
.. versionadded:: 1.4.0
"""
@@ -106,7 +125,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
return LinearRegressionModel(java_model)
-class LinearRegressionModel(JavaModel):
+class LinearRegressionModel(JavaModel, MLWritable, MLReadable):
"""
Model fitted by LinearRegression.
@@ -821,9 +840,10 @@ class AFTSurvivalRegressionModel(JavaModel):
if __name__ == "__main__":
import doctest
+ import pyspark.ml.regression
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
- globs = globals().copy()
+ globs = pyspark.ml.regression.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.regression tests")