diff options
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r-- | python/pyspark/ml/regression.py | 30 |
1 files changed, 25 insertions, 5 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 74a2248ed0..20dc6c2db9 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -18,9 +18,9 @@ import warnings from pyspark import since -from pyspark.ml.util import keyword_only -from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * +from pyspark.ml.util import * +from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.mllib.common import inherit_doc @@ -35,7 +35,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver, HasWeightCol): + HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable): """ Linear regression. @@ -68,6 +68,25 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lr_path = path + "/lr" + >>> lr.save(lr_path) + >>> lr2 = LinearRegression.load(lr_path) + >>> lr2.getMaxIter() + 5 + >>> model_path = path + "/lr_model" + >>> model.save(model_path) + >>> model2 = LinearRegressionModel.load(model_path) + >>> model.coefficients[0] == model2.coefficients[0] + True + >>> model.intercept == model2.intercept + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass .. versionadded:: 1.4.0 """ @@ -106,7 +125,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction return LinearRegressionModel(java_model) -class LinearRegressionModel(JavaModel): +class LinearRegressionModel(JavaModel, MLWritable, MLReadable): """ Model fitted by LinearRegression. @@ -821,9 +840,10 @@ class AFTSurvivalRegressionModel(JavaModel): if __name__ == "__main__": import doctest + import pyspark.ml.regression from pyspark.context import SparkContext from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.ml.regression.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.regression tests") |