aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/regression.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--python/pyspark/ml/regression.py16
1 files changed, 14 insertions, 2 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 59d4fe3cf4..37648549de 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -389,7 +389,7 @@ class GBTParams(TreeEnsembleParams):
@inherit_doc
class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
- HasSeed):
+ HasSeed, JavaMLWritable, JavaMLReadable):
"""
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
learning algorithm for regression.
@@ -413,6 +413,18 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
+ >>> dtr_path = temp_path + "/dtr"
+ >>> dt.save(dtr_path)
+ >>> dt2 = DecisionTreeRegressor.load(dtr_path)
+ >>> dt2.getMaxDepth()
+ 2
+ >>> model_path = temp_path + "/dtr_model"
+ >>> model.save(model_path)
+ >>> model2 = DecisionTreeRegressionModel.load(model_path)
+ >>> model.numNodes == model2.numNodes
+ True
+ >>> model.depth == model2.depth
+ True
.. versionadded:: 1.4.0
"""
@@ -498,7 +510,7 @@ class TreeEnsembleModels(JavaModel):
@inherit_doc
-class DecisionTreeRegressionModel(DecisionTreeModel):
+class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by DecisionTreeRegressor.