aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-04-14 21:36:03 -0700
committerXiangrui Meng <meng@databricks.com>2016-04-14 21:36:03 -0700
commitb9613239d303bc0f451233852c1eb1219a69875e (patch)
treecd78313211d0591320724bc67785632a71b20922 /python/pyspark
parent297ba3f1b49cc37d9891a529142c553e0a5e2d62 (diff)
downloadspark-b9613239d303bc0f451233852c1eb1219a69875e.tar.gz
spark-b9613239d303bc0f451233852c1eb1219a69875e.tar.bz2
spark-b9613239d303bc0f451233852c1eb1219a69875e.zip
[SPARK-14374][ML][PYSPARK] PySpark ml GBTClassifier, Regressor support export/import
## What changes were proposed in this pull request? PySpark ml GBTClassifier, Regressor support export/import. ## How was this patch tested? Doc test. cc jkbradley Author: Yanbo Liang <ybliang8@gmail.com> Closes #12383 from yanboliang/spark-14374.
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/ml/classification.py17
-rw-r--r--python/pyspark/ml/regression.py17
2 files changed, 30 insertions, 4 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 922f8069fa..6ef119a426 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -739,7 +739,8 @@ class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaML
@inherit_doc
class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
- GBTParams, HasCheckpointInterval, HasStepSize, HasSeed):
+ GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable,
+ JavaMLReadable):
"""
`http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)`
learning algorithm for classification.
@@ -767,6 +768,18 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
+ >>> gbtc_path = temp_path + "gbtc"
+ >>> gbt.save(gbtc_path)
+ >>> gbt2 = GBTClassifier.load(gbtc_path)
+ >>> gbt2.getMaxDepth()
+ 2
+ >>> model_path = temp_path + "gbtc_model"
+ >>> model.save(model_path)
+ >>> model2 = GBTClassificationModel.load(model_path)
+ >>> model.featureImportances == model2.featureImportances
+ True
+ >>> model.treeWeights == model2.treeWeights
+ True
.. versionadded:: 1.4.0
"""
@@ -831,7 +844,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
return self.getOrDefault(self.lossType)
-class GBTClassificationModel(TreeEnsembleModels):
+class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
"""
Model fitted by GBTClassifier.
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index c064fe500c..3c7852526a 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -902,7 +902,8 @@ class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLRead
@inherit_doc
class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
- GBTParams, HasCheckpointInterval, HasStepSize, HasSeed):
+ GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable,
+ JavaMLReadable):
"""
`http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)`
learning algorithm for regression.
@@ -925,6 +926,18 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
+ >>> gbtr_path = temp_path + "gbtr"
+ >>> gbt.save(gbtr_path)
+ >>> gbt2 = GBTRegressor.load(gbtr_path)
+ >>> gbt2.getMaxDepth()
+ 2
+ >>> model_path = temp_path + "gbtr_model"
+ >>> model.save(model_path)
+ >>> model2 = GBTRegressionModel.load(model_path)
+ >>> model.featureImportances == model2.featureImportances
+ True
+ >>> model.treeWeights == model2.treeWeights
+ True
.. versionadded:: 1.4.0
"""
@@ -989,7 +1002,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
return self.getOrDefault(self.lossType)
-class GBTRegressionModel(TreeEnsembleModels):
+class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
"""
Model fitted by GBTRegressor.