From 72353311d3a37cb523c5bdd8072ffdff99af9749 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Thu, 2 Jun 2016 15:55:14 -0700
Subject: [SPARK-15092][SPARK-15139][PYSPARK][ML] Pyspark TreeEnsemble missing methods

## What changes were proposed in this pull request?

Add `toDebugString` and `totalNumNodes` to `TreeEnsembleModels` and add `toDebugString` to `DecisionTreeModel`

## How was this patch tested?

Extended doc tests.

Author: Holden Karau

Closes #12919 from holdenk/SPARK-15139-pyspark-treeEnsemble-missing-methods.
---
 python/pyspark/ml/classification.py | 20 ++++++++++++++++
 python/pyspark/ml/regression.py     | 48 ++++++++++++++++++++++++++++++++++++-
 2 files changed, 67 insertions(+), 1 deletion(-)

(limited to 'python')

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index ea660d7808..177cf9d72c 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -512,6 +512,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     1
     >>> model.featureImportances
     SparseVector(1, {0: 1.0})
+    >>> print(model.toDebugString)
+    DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
     >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> result = model.transform(test0).head()
     >>> result.prediction
@@ -650,6 +652,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     1.0
+    >>> model.trees
+    [DecisionTreeClassificationModel (uid=...) of depth..., DecisionTreeClassificationModel...]
     >>> rfc_path = temp_path + "/rfc"
     >>> rf.save(rfc_path)
     >>> rf2 = RandomForestClassifier.load(rfc_path)
@@ -730,6 +734,12 @@ class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaML
         """
         return self._call_java("featureImportances")
 
+    @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent Estimators."""
+        return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
+
 
 @inherit_doc
 class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
@@ -772,6 +782,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
     >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     1.0
+    >>> model.totalNumNodes
+    15
+    >>> print(model.toDebugString)
+    GBTClassificationModel (uid=...)...with 5 trees...
     >>> gbtc_path = temp_path + "gbtc"
     >>> gbt.save(gbtc_path)
     >>> gbt2 = GBTClassifier.load(gbtc_path)
@@ -869,6 +883,12 @@ class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable)
         """
         return self._call_java("featureImportances")
 
+    @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent Estimators."""
+        return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+
 
 @inherit_doc
 class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 1b7af7ef59..7c79ab73c7 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -593,7 +593,7 @@ class RandomForestParams(TreeEnsembleParams):
     featureSubsetStrategy = \
         Param(Params._dummy(), "featureSubsetStrategy",
               "The number of features to consider for splits at each tree node. Supported " +
-              "options: " + ", ".join(supportedFeatureSubsetStrategies),
+              "options: " + ", ".join(supportedFeatureSubsetStrategies) + " (0.0-1.0], [1-n].",
               typeConverter=TypeConverters.toString)
 
     def __init__(self):
@@ -744,6 +744,12 @@ class DecisionTreeModel(JavaModel):
         """Return depth of the decision tree."""
         return self._call_java("depth")
 
+    @property
+    @since("2.0.0")
+    def toDebugString(self):
+        """Full description of model."""
+        return self._call_java("toDebugString")
+
     def __repr__(self):
         return self._call_java("toString")
 
@@ -758,12 +764,36 @@ class TreeEnsembleModels(JavaModel):
     .. versionadded:: 1.5.0
     """
 
+    @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent Estimators."""
+        return [DecisionTreeModel(m) for m in list(self._call_java("trees"))]
+
+    @property
+    @since("2.0.0")
+    def getNumTrees(self):
+        """Number of trees in ensemble."""
+        return self._call_java("getNumTrees")
+
     @property
     @since("1.5.0")
     def treeWeights(self):
         """Return the weights for each tree"""
         return list(self._call_java("javaTreeWeights"))
 
+    @property
+    @since("2.0.0")
+    def totalNumNodes(self):
+        """Total number of nodes, summed over all trees in the ensemble."""
+        return self._call_java("totalNumNodes")
+
+    @property
+    @since("2.0.0")
+    def toDebugString(self):
+        """Full description of model."""
+        return self._call_java("toDebugString")
+
     def __repr__(self):
         return self._call_java("toString")
 
@@ -825,6 +855,10 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
+    >>> model.trees
+    [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
+    >>> model.getNumTrees
+    2
     >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     0.5
@@ -896,6 +930,12 @@ class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLRead
     .. versionadded:: 1.4.0
     """
 
+    @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent Estimators."""
+        return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+
     @property
     @since("2.0.0")
     def featureImportances(self):
@@ -1045,6 +1085,12 @@ class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
         """
         return self._call_java("featureImportances")
 
+    @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent Estimators."""
+        return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+
 
 @inherit_doc
 class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
-- 
cgit v1.2.3