aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/regression.py
diff options
context:
space:
mode:
authorTommy YU <tummyyu@163.com>2016-02-25 21:09:02 -0800
committerXiangrui Meng <meng@databricks.com>2016-02-25 21:09:02 -0800
commitf3be369ef7b78944ce9cafc94b19df52772cea58 (patch)
tree33071054830637f9d3fe39e1289aa6d8c76e098a /python/pyspark/ml/regression.py
parent90d07154c2cef3d1095cb3caeafa7003218a3e49 (diff)
downloadspark-f3be369ef7b78944ce9cafc94b19df52772cea58.tar.gz
spark-f3be369ef7b78944ce9cafc94b19df52772cea58.tar.bz2
spark-f3be369ef7b78944ce9cafc94b19df52772cea58.zip
[SPARK-13033] [ML] [PYSPARK] Add import/export for ml.regression
Add export/import for all estimators and transformers(which have Scala implementation) under pyspark/ml/regression.py. yanboliang Please help to review. For doctest, I though it's enough to add one since it's common usage. But I can add to all if we want it. Author: Tommy YU <tummyyu@163.com> Closes #11000 from Wenpei/spark-13033-ml.regression-exprot-import and squashes the following commits: 3646b36 [Tommy YU] address review comments 9cddc98 [Tommy YU] change base on review and pr 11197 cc61d9d [Tommy YU] remove default parameter set 19535d4 [Tommy YU] add export/import to regression 44a9dc2 [Tommy YU] add import/export for ml.regression
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--python/pyspark/ml/regression.py34
1 files changed, 30 insertions, 4 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index de4a751a54..6b994fe9f9 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -154,7 +154,7 @@ class LinearRegressionModel(JavaModel, MLWritable, MLReadable):
@inherit_doc
class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
- HasWeightCol):
+ HasWeightCol, MLWritable, MLReadable):
"""
.. note:: Experimental
@@ -172,6 +172,18 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
0.0
>>> model.boundaries
DenseVector([0.0, 1.0])
+ >>> ir_path = temp_path + "/ir"
+ >>> ir.save(ir_path)
+ >>> ir2 = IsotonicRegression.load(ir_path)
+ >>> ir2.getIsotonic()
+ True
+ >>> model_path = temp_path + "/ir_model"
+ >>> model.save(model_path)
+ >>> model2 = IsotonicRegressionModel.load(model_path)
+ >>> model.boundaries == model2.boundaries
+ True
+ >>> model.predictions == model2.predictions
+ True
"""
isotonic = \
@@ -237,7 +249,7 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
return self.getOrDefault(self.featureIndex)
-class IsotonicRegressionModel(JavaModel):
+class IsotonicRegressionModel(JavaModel, MLWritable, MLReadable):
"""
.. note:: Experimental
@@ -663,7 +675,7 @@ class GBTRegressionModel(TreeEnsembleModels):
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
- HasFitIntercept, HasMaxIter, HasTol):
+ HasFitIntercept, HasMaxIter, HasTol, MLWritable, MLReadable):
"""
Accelerated Failure Time (AFT) Model Survival Regression
@@ -690,6 +702,20 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
| 0.0|(1,[],[])| 0.0| 1.0|
+-----+---------+------+----------+
...
+ >>> aftsr_path = temp_path + "/aftsr"
+ >>> aftsr.save(aftsr_path)
+ >>> aftsr2 = AFTSurvivalRegression.load(aftsr_path)
+ >>> aftsr2.getMaxIter()
+ 100
+ >>> model_path = temp_path + "/aftsr_model"
+ >>> model.save(model_path)
+ >>> model2 = AFTSurvivalRegressionModel.load(model_path)
+ >>> model.coefficients == model2.coefficients
+ True
+ >>> model.intercept == model2.intercept
+ True
+ >>> model.scale == model2.scale
+ True
.. versionadded:: 1.6.0
"""
@@ -787,7 +813,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
return self.getOrDefault(self.quantilesCol)
-class AFTSurvivalRegressionModel(JavaModel):
+class AFTSurvivalRegressionModel(JavaModel, MLWritable, MLReadable):
"""
Model fitted by AFTSurvivalRegression.