From 72353311d3a37cb523c5bdd8072ffdff99af9749 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 2 Jun 2016 15:55:14 -0700 Subject: [SPARK-15092][SPARK-15139][PYSPARK][ML] Pyspark TreeEnsemble missing methods ## What changes were proposed in this pull request? Add `toDebugString` and `totalNumNodes` to `TreeEnsembleModels` and add `toDebugString` to `DecisionTreeModel` ## How was this patch tested? Extended doc tests. Author: Holden Karau Closes #12919 from holdenk/SPARK-15139-pyspark-treeEnsemble-missing-methods. --- python/pyspark/ml/regression.py | 48 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) (limited to 'python/pyspark/ml/regression.py') diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 1b7af7ef59..7c79ab73c7 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -593,7 +593,7 @@ class RandomForestParams(TreeEnsembleParams): featureSubsetStrategy = \ Param(Params._dummy(), "featureSubsetStrategy", "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies), + "options: " + ", ".join(supportedFeatureSubsetStrategies) + " (0.0-1.0], [1-n].", typeConverter=TypeConverters.toString) def __init__(self): @@ -744,6 +744,12 @@ class DecisionTreeModel(JavaModel): """Return depth of the decision tree.""" return self._call_java("depth") + @property + @since("2.0.0") + def toDebugString(self): + """Full description of model.""" + return self._call_java("toDebugString") + def __repr__(self): return self._call_java("toString") @@ -758,12 +764,36 @@ class TreeEnsembleModels(JavaModel): .. versionadded:: 1.5.0 """ + @property + @since("2.0.0") + def trees(self): + """Trees in this ensemble. Warning: These have null parent Estimators.""" + return [DecisionTreeModel(m) for m in list(self._call_java("trees"))] + + @property + @since("2.0.0") + def getNumTrees(self): + """Number of trees in ensemble.""" + return self._call_java("getNumTrees") + @property @since("1.5.0") def treeWeights(self): """Return the weights for each tree""" return list(self._call_java("javaTreeWeights")) + @property + @since("2.0.0") + def totalNumNodes(self): + """Total number of nodes, summed over all trees in the ensemble.""" + return self._call_java("totalNumNodes") + + @property + @since("2.0.0") + def toDebugString(self): + """Full description of model.""" + return self._call_java("toDebugString") + def __repr__(self): return self._call_java("toString") @@ -825,6 +855,10 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 + >>> model.trees + [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> model.getNumTrees + 2 >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 0.5 @@ -896,6 +930,12 @@ class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLRead .. versionadded:: 1.4.0 """ + @property + @since("2.0.0") + def trees(self): + """Trees in this ensemble. Warning: These have null parent Estimators.""" + return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @property @since("2.0.0") def featureImportances(self): @@ -1045,6 +1085,12 @@ class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ return self._call_java("featureImportances") + @property + @since("2.0.0") + def trees(self): + """Trees in this ensemble. Warning: These have null parent Estimators.""" + return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, -- cgit v1.2.3