diff options
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r-- | python/pyspark/ml/regression.py | 48 |
1 files changed, 47 insertions, 1 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 1b7af7ef59..7c79ab73c7 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -593,7 +593,7 @@ class RandomForestParams(TreeEnsembleParams): featureSubsetStrategy = \ Param(Params._dummy(), "featureSubsetStrategy", "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies), + "options: " + ", ".join(supportedFeatureSubsetStrategies) + " (0.0-1.0], [1-n].", typeConverter=TypeConverters.toString) def __init__(self): @@ -744,6 +744,12 @@ class DecisionTreeModel(JavaModel): """Return depth of the decision tree.""" return self._call_java("depth") + @property + @since("2.0.0") + def toDebugString(self): + """Full description of model.""" + return self._call_java("toDebugString") + def __repr__(self): return self._call_java("toString") @@ -759,11 +765,35 @@ class TreeEnsembleModels(JavaModel): """ @property + @since("2.0.0") + def trees(self): + """Trees in this ensemble. Warning: These have null parent Estimators.""" + return [DecisionTreeModel(m) for m in list(self._call_java("trees"))] + + @property + @since("2.0.0") + def getNumTrees(self): + """Number of trees in ensemble.""" + return self._call_java("getNumTrees") + + @property @since("1.5.0") def treeWeights(self): """Return the weights for each tree""" return list(self._call_java("javaTreeWeights")) + @property + @since("2.0.0") + def totalNumNodes(self): + """Total number of nodes, summed over all trees in the ensemble.""" + return self._call_java("totalNumNodes") + + @property + @since("2.0.0") + def toDebugString(self): + """Full description of model.""" + return self._call_java("toDebugString") + def __repr__(self): return self._call_java("toString") @@ -825,6 +855,10 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 + >>> model.trees + [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> model.getNumTrees + 2 >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 0.5 @@ -898,6 +932,12 @@ class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLRead @property @since("2.0.0") + def trees(self): + """Trees in this ensemble. Warning: These have null parent Estimators.""" + return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + + @property + @since("2.0.0") def featureImportances(self): """ Estimate of the importance of each feature. @@ -1045,6 +1085,12 @@ class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ return self._call_java("featureImportances") + @property + @since("2.0.0") + def trees(self): + """Trees in this ensemble. Warning: These have null parent Estimators.""" + return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, |