diff options
author | Holden Karau <holden@us.ibm.com> | 2016-06-02 15:55:14 -0700 |
---|---|---|
committer | Nick Pentreath <nickp@za.ibm.com> | 2016-06-02 15:55:14 -0700 |
commit | 72353311d3a37cb523c5bdd8072ffdff99af9749 (patch) | |
tree | 45a2def7a6f97388e468203cd87c846ae28ce9b6 /python/pyspark/ml/classification.py | |
parent | d109a1beeef5bca1e683247e0a5db4ec841bf3ba (diff) | |
download | spark-72353311d3a37cb523c5bdd8072ffdff99af9749.tar.gz spark-72353311d3a37cb523c5bdd8072ffdff99af9749.tar.bz2 spark-72353311d3a37cb523c5bdd8072ffdff99af9749.zip |
[SPARK-15092][SPARK-15139][PYSPARK][ML] PySpark TreeEnsemble missing methods
## What changes were proposed in this pull request?
Add `toDebugString` and `totalNumNodes` to `TreeEnsembleModels`, and add `toDebugString` to `DecisionTreeModel`.
## How was this patch tested?
Extended doc tests.
Author: Holden Karau <holden@us.ibm.com>
Closes #12919 from holdenk/SPARK-15139-pyspark-treeEnsemble-missing-methods.
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r-- | python/pyspark/ml/classification.py | 20 |
1 file changed, 20 insertions, 0 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ea660d7808..177cf9d72c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -512,6 +512,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred 1 >>> model.featureImportances SparseVector(1, {0: 1.0}) + >>> print(model.toDebugString) + DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes... >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> result = model.transform(test0).head() >>> result.prediction @@ -650,6 +652,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> model.trees + [DecisionTreeClassificationModel (uid=...) of depth..., DecisionTreeClassificationModel...] >>> rfc_path = temp_path + "/rfc" >>> rf.save(rfc_path) >>> rf2 = RandomForestClassifier.load(rfc_path) @@ -730,6 +734,12 @@ class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaML """ return self._call_java("featureImportances") + @property + @since("2.0.0") + def trees(self): + """Trees in this ensemble. Warning: These have null parent Estimators.""" + return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))] + @inherit_doc class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, @@ -772,6 +782,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> model.totalNumNodes + 15 + >>> print(model.toDebugString) + GBTClassificationModel (uid=...)...with 5 trees... 
>>> gbtc_path = temp_path + "gbtc" >>> gbt.save(gbtc_path) >>> gbt2 = GBTClassifier.load(gbtc_path) @@ -869,6 +883,12 @@ class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable) """ return self._call_java("featureImportances") + @property + @since("2.0.0") + def trees(self): + """Trees in this ensemble. Warning: These have null parent Estimators.""" + return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, |