diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/ml/classification.py | 17 | ||||
-rw-r--r-- | python/pyspark/ml/regression.py | 17 |
2 files changed, 30 insertions, 4 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 922f8069fa..6ef119a426 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -739,7 +739,8 @@ class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaML @inherit_doc class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - GBTParams, HasCheckpointInterval, HasStepSize, HasSeed): + GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, + JavaMLReadable): """ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` learning algorithm for classification. @@ -767,6 +768,18 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> gbtc_path = temp_path + "gbtc" + >>> gbt.save(gbtc_path) + >>> gbt2 = GBTClassifier.load(gbtc_path) + >>> gbt2.getMaxDepth() + 2 + >>> model_path = temp_path + "gbtc_model" + >>> model.save(model_path) + >>> model2 = GBTClassificationModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + >>> model.treeWeights == model2.treeWeights + True .. versionadded:: 1.4.0 """ @@ -831,7 +844,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol return self.getOrDefault(self.lossType) -class GBTClassificationModel(TreeEnsembleModels): +class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTClassifier. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index c064fe500c..3c7852526a 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -902,7 +902,8 @@ class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLRead @inherit_doc class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - GBTParams, HasCheckpointInterval, HasStepSize, HasSeed): + GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, + JavaMLReadable): """ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` learning algorithm for regression. @@ -925,6 +926,18 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> gbtr_path = temp_path + "gbtr" + >>> gbt.save(gbtr_path) + >>> gbt2 = GBTRegressor.load(gbtr_path) + >>> gbt2.getMaxDepth() + 2 + >>> model_path = temp_path + "gbtr_model" + >>> model.save(model_path) + >>> model2 = GBTRegressionModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + >>> model.treeWeights == model2.treeWeights + True .. versionadded:: 1.4.0 """ @@ -989,7 +1002,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, return self.getOrDefault(self.lossType) -class GBTRegressionModel(TreeEnsembleModels): +class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTRegressor. |