diff options
author | Tommy YU <tummyyu@163.com> | 2016-02-25 21:09:02 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-02-25 21:09:02 -0800 |
commit | f3be369ef7b78944ce9cafc94b19df52772cea58 (patch) | |
tree | 33071054830637f9d3fe39e1289aa6d8c76e098a /python/pyspark/ml/regression.py | |
parent | 90d07154c2cef3d1095cb3caeafa7003218a3e49 (diff) | |
download | spark-f3be369ef7b78944ce9cafc94b19df52772cea58.tar.gz spark-f3be369ef7b78944ce9cafc94b19df52772cea58.tar.bz2 spark-f3be369ef7b78944ce9cafc94b19df52772cea58.zip |
[SPARK-13033] [ML] [PYSPARK] Add import/export for ml.regression
Add export/import for all estimators and transformers(which have Scala implementation) under pyspark/ml/regression.py.
yanboliang Please help to review.
For doctest, I though it's enough to add one since it's common usage. But I can add to all if we want it.
Author: Tommy YU <tummyyu@163.com>
Closes #11000 from Wenpei/spark-13033-ml.regression-exprot-import and squashes the following commits:
3646b36 [Tommy YU] address review comments
9cddc98 [Tommy YU] change base on review and pr 11197
cc61d9d [Tommy YU] remove default parameter set
19535d4 [Tommy YU] add export/import to regression
44a9dc2 [Tommy YU] add import/export for ml.regression
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r-- | python/pyspark/ml/regression.py | 34 |
1 files changed, 30 insertions, 4 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index de4a751a54..6b994fe9f9 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -154,7 +154,7 @@ class LinearRegressionModel(JavaModel, MLWritable, MLReadable): @inherit_doc class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasWeightCol): + HasWeightCol, MLWritable, MLReadable): """ .. note:: Experimental @@ -172,6 +172,18 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti 0.0 >>> model.boundaries DenseVector([0.0, 1.0]) + >>> ir_path = temp_path + "/ir" + >>> ir.save(ir_path) + >>> ir2 = IsotonicRegression.load(ir_path) + >>> ir2.getIsotonic() + True + >>> model_path = temp_path + "/ir_model" + >>> model.save(model_path) + >>> model2 = IsotonicRegressionModel.load(model_path) + >>> model.boundaries == model2.boundaries + True + >>> model.predictions == model2.predictions + True """ isotonic = \ @@ -237,7 +249,7 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti return self.getOrDefault(self.featureIndex) -class IsotonicRegressionModel(JavaModel): +class IsotonicRegressionModel(JavaModel, MLWritable, MLReadable): """ .. note:: Experimental @@ -663,7 +675,7 @@ class GBTRegressionModel(TreeEnsembleModels): @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasFitIntercept, HasMaxIter, HasTol): + HasFitIntercept, HasMaxIter, HasTol, MLWritable, MLReadable): """ Accelerated Failure Time (AFT) Model Survival Regression @@ -690,6 +702,20 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi | 0.0|(1,[],[])| 0.0| 1.0| +-----+---------+------+----------+ ... + >>> aftsr_path = temp_path + "/aftsr" + >>> aftsr.save(aftsr_path) + >>> aftsr2 = AFTSurvivalRegression.load(aftsr_path) + >>> aftsr2.getMaxIter() + 100 + >>> model_path = temp_path + "/aftsr_model" + >>> model.save(model_path) + >>> model2 = AFTSurvivalRegressionModel.load(model_path) + >>> model.coefficients == model2.coefficients + True + >>> model.intercept == model2.intercept + True + >>> model.scale == model2.scale + True .. versionadded:: 1.6.0 """ @@ -787,7 +813,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi return self.getOrDefault(self.quantilesCol) -class AFTSurvivalRegressionModel(JavaModel): +class AFTSurvivalRegressionModel(JavaModel, MLWritable, MLReadable): """ Model fitted by AFTSurvivalRegression. |