diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-04-14 21:36:03 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-04-14 21:36:03 -0700 |
commit | b9613239d303bc0f451233852c1eb1219a69875e (patch) | |
tree | cd78313211d0591320724bc67785632a71b20922 /python/pyspark/ml/regression.py | |
parent | 297ba3f1b49cc37d9891a529142c553e0a5e2d62 (diff) | |
download | spark-b9613239d303bc0f451233852c1eb1219a69875e.tar.gz spark-b9613239d303bc0f451233852c1eb1219a69875e.tar.bz2 spark-b9613239d303bc0f451233852c1eb1219a69875e.zip |
[SPARK-14374][ML][PYSPARK] PySpark ml GBTClassifier, Regressor support export/import
## What changes were proposed in this pull request?
PySpark ml GBTClassifier, Regressor support export/import.
## How was this patch tested?
Doc test.
cc jkbradley
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #12383 from yanboliang/spark-14374.
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r-- | python/pyspark/ml/regression.py | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index c064fe500c..3c7852526a 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -902,7 +902,8 @@ class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLRead @inherit_doc class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - GBTParams, HasCheckpointInterval, HasStepSize, HasSeed): + GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, + JavaMLReadable): """ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` learning algorithm for regression. @@ -925,6 +926,18 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> gbtr_path = temp_path + "gbtr" + >>> gbt.save(gbtr_path) + >>> gbt2 = GBTRegressor.load(gbtr_path) + >>> gbt2.getMaxDepth() + 2 + >>> model_path = temp_path + "gbtr_model" + >>> model.save(model_path) + >>> model2 = GBTRegressionModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + >>> model.treeWeights == model2.treeWeights + True .. versionadded:: 1.4.0 """ @@ -989,7 +1002,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, return self.getOrDefault(self.lossType) -class GBTRegressionModel(TreeEnsembleModels): +class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTRegressor. |