From b9613239d303bc0f451233852c1eb1219a69875e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 14 Apr 2016 21:36:03 -0700 Subject: [SPARK-14374][ML][PYSPARK] PySpark ml GBTClassifier, Regressor support export/import ## What changes were proposed in this pull request? PySpark ml GBTClassifier, Regressor support export/import. ## How was this patch tested? Doc test. cc jkbradley Author: Yanbo Liang Closes #12383 from yanboliang/spark-14374. --- python/pyspark/ml/classification.py | 17 +++++++++++++++-- python/pyspark/ml/regression.py | 17 +++++++++++++++-- 2 files changed, 30 insertions(+), 4 deletions(-) (limited to 'python/pyspark') diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 922f8069fa..6ef119a426 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -739,7 +739,8 @@ class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaML @inherit_doc class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - GBTParams, HasCheckpointInterval, HasStepSize, HasSeed): + GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, + JavaMLReadable): """ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` learning algorithm for classification. @@ -767,6 +768,18 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> gbtc_path = temp_path + "gbtc" + >>> gbt.save(gbtc_path) + >>> gbt2 = GBTClassifier.load(gbtc_path) + >>> gbt2.getMaxDepth() + 2 + >>> model_path = temp_path + "gbtc_model" + >>> model.save(model_path) + >>> model2 = GBTClassificationModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + >>> model.treeWeights == model2.treeWeights + True .. versionadded:: 1.4.0 """ @@ -831,7 +844,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol return self.getOrDefault(self.lossType) -class GBTClassificationModel(TreeEnsembleModels): +class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTClassifier. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index c064fe500c..3c7852526a 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -902,7 +902,8 @@ class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLRead @inherit_doc class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - GBTParams, HasCheckpointInterval, HasStepSize, HasSeed): + GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, + JavaMLReadable): """ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` learning algorithm for regression. @@ -925,6 +926,18 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> gbtr_path = temp_path + "gbtr" + >>> gbt.save(gbtr_path) + >>> gbt2 = GBTRegressor.load(gbtr_path) + >>> gbt2.getMaxDepth() + 2 + >>> model_path = temp_path + "gbtr_model" + >>> model.save(model_path) + >>> model2 = GBTRegressionModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + >>> model.treeWeights == model2.treeWeights + True .. versionadded:: 1.4.0 """ @@ -989,7 +1002,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, return self.getOrDefault(self.lossType) -class GBTRegressionModel(TreeEnsembleModels): +class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTRegressor. -- cgit v1.2.3