aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 121b9262dd..a3cd91790c 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -21,8 +21,8 @@ import warnings
from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
from pyspark.ml.param.shared import *
-from pyspark.ml.regression import (
- RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
+from pyspark.ml.regression import DecisionTreeModel, DecisionTreeRegressionModel, \
+ RandomForestParams, TreeEnsembleModels, TreeEnsembleParams
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.wrapper import JavaWrapper
@@ -798,6 +798,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
True
>>> model.treeWeights == model2.treeWeights
True
+ >>> model.trees
+ [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
.. versionadded:: 1.4.0
"""