aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorBryan Cutler <cutlerb@gmail.com>2016-06-20 16:28:11 -0700
committerXiangrui Meng <meng@databricks.com>2016-06-20 16:28:11 -0700
commita42bf555326b75c8251be77db68105c29e8c95c4 (patch)
treed4eff56ebb90edab7d6c01f51926db31a788b772 /python
parent6daa8cf1a642a669cd3a0305036c4390e4336a73 (diff)
downloadspark-a42bf555326b75c8251be77db68105c29e8c95c4.tar.gz
spark-a42bf555326b75c8251be77db68105c29e8c95c4.tar.bz2
spark-a42bf555326b75c8251be77db68105c29e8c95c4.zip
[SPARK-16079][PYSPARK][ML] Added missing import for DecisionTreeRegressionModel used in GBTClassificationModel
## What changes were proposed in this pull request? Fixed missing import for DecisionTreeRegressionModel used in GBTClassificationModel trees method. ## How was this patch tested? Local tests Author: Bryan Cutler <cutlerb@gmail.com> Closes #13787 from BryanCutler/pyspark-GBTClassificationModel-import-SPARK-16079.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/classification.py6
-rw-r--r--python/pyspark/ml/regression.py2
2 files changed, 6 insertions, 2 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 121b9262dd..a3cd91790c 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -21,8 +21,8 @@ import warnings
from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
from pyspark.ml.param.shared import *
-from pyspark.ml.regression import (
- RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
+from pyspark.ml.regression import DecisionTreeModel, DecisionTreeRegressionModel, \
+ RandomForestParams, TreeEnsembleModels, TreeEnsembleParams
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.wrapper import JavaWrapper
@@ -798,6 +798,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
True
>>> model.treeWeights == model2.treeWeights
True
+ >>> model.trees
+ [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
.. versionadded:: 1.4.0
"""
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index db31993f0f..8d2378d51f 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -994,6 +994,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
True
>>> model.treeWeights == model2.treeWeights
True
+ >>> model.trees
+ [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
.. versionadded:: 1.4.0
"""