aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index ea660d7808..177cf9d72c 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -512,6 +512,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1
>>> model.featureImportances
SparseVector(1, {0: 1.0})
+ >>> print(model.toDebugString)
+ DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> result = model.transform(test0).head()
>>> result.prediction
@@ -650,6 +652,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
+ >>> model.trees
+ [DecisionTreeClassificationModel (uid=...) of depth..., DecisionTreeClassificationModel...]
>>> rfc_path = temp_path + "/rfc"
>>> rf.save(rfc_path)
>>> rf2 = RandomForestClassifier.load(rfc_path)
@@ -730,6 +734,12 @@ class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaML
"""
return self._call_java("featureImportances")
+ @property
+ @since("2.0.0")
+ def trees(self):
+ """Trees in this ensemble. Warning: These have null parent Estimators."""
+ return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
+
@inherit_doc
class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
@@ -772,6 +782,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
+ >>> model.totalNumNodes
+ 15
+ >>> print(model.toDebugString)
+ GBTClassificationModel (uid=...)...with 5 trees...
>>> gbtc_path = temp_path + "gbtc"
>>> gbt.save(gbtc_path)
>>> gbt2 = GBTClassifier.load(gbtc_path)
@@ -869,6 +883,12 @@ class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable)
"""
return self._call_java("featureImportances")
+ @property
+ @since("2.0.0")
+ def trees(self):
+ """Trees in this ensemble. Warning: These have null parent Estimators."""
+ return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+
@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,