aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/regression.py
diff options
context:
space:
mode:
authorHolden Karau <holden@us.ibm.com>2016-06-02 15:55:14 -0700
committerNick Pentreath <nickp@za.ibm.com>2016-06-02 15:55:14 -0700
commit72353311d3a37cb523c5bdd8072ffdff99af9749 (patch)
tree45a2def7a6f97388e468203cd87c846ae28ce9b6 /python/pyspark/ml/regression.py
parentd109a1beeef5bca1e683247e0a5db4ec841bf3ba (diff)
downloadspark-72353311d3a37cb523c5bdd8072ffdff99af9749.tar.gz
spark-72353311d3a37cb523c5bdd8072ffdff99af9749.tar.bz2
spark-72353311d3a37cb523c5bdd8072ffdff99af9749.zip
[SPARK-15092][SPARK-15139][PYSPARK][ML] Pyspark TreeEnsemble missing methods
## What changes were proposed in this pull request? Add `toDebugString` and `totalNumNodes` to `TreeEnsembleModels` and add `toDebugString` to `DecisionTreeModel` ## How was this patch tested? Extended doc tests. Author: Holden Karau <holden@us.ibm.com> Closes #12919 from holdenk/SPARK-15139-pyspark-treeEnsemble-missing-methods.
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--python/pyspark/ml/regression.py48
1 files changed, 47 insertions, 1 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 1b7af7ef59..7c79ab73c7 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -593,7 +593,7 @@ class RandomForestParams(TreeEnsembleParams):
featureSubsetStrategy = \
Param(Params._dummy(), "featureSubsetStrategy",
"The number of features to consider for splits at each tree node. Supported " +
- "options: " + ", ".join(supportedFeatureSubsetStrategies),
+ "options: " + ", ".join(supportedFeatureSubsetStrategies) + " (0.0-1.0], [1-n].",
typeConverter=TypeConverters.toString)
def __init__(self):
@@ -744,6 +744,12 @@ class DecisionTreeModel(JavaModel):
"""Return depth of the decision tree."""
return self._call_java("depth")
+ @property
+ @since("2.0.0")
+ def toDebugString(self):
+ """Full description of model."""
+ return self._call_java("toDebugString")
+
def __repr__(self):
return self._call_java("toString")
@@ -759,11 +765,35 @@ class TreeEnsembleModels(JavaModel):
"""
@property
+ @since("2.0.0")
+ def trees(self):
+ """Trees in this ensemble. Warning: These have null parent Estimators."""
+ return [DecisionTreeModel(m) for m in list(self._call_java("trees"))]
+
+ @property
+ @since("2.0.0")
+ def getNumTrees(self):
+ """Number of trees in ensemble."""
+ return self._call_java("getNumTrees")
+
+ @property
@since("1.5.0")
def treeWeights(self):
"""Return the weights for each tree"""
return list(self._call_java("javaTreeWeights"))
+ @property
+ @since("2.0.0")
+ def totalNumNodes(self):
+ """Total number of nodes, summed over all trees in the ensemble."""
+ return self._call_java("totalNumNodes")
+
+ @property
+ @since("2.0.0")
+ def toDebugString(self):
+ """Full description of model."""
+ return self._call_java("toDebugString")
+
def __repr__(self):
return self._call_java("toString")
@@ -825,6 +855,10 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
+ >>> model.trees
+ [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
+ >>> model.getNumTrees
+ 2
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
0.5
@@ -898,6 +932,12 @@ class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLRead
@property
@since("2.0.0")
+ def trees(self):
+ """Trees in this ensemble. Warning: These have null parent Estimators."""
+ return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+
+ @property
+ @since("2.0.0")
def featureImportances(self):
"""
Estimate of the importance of each feature.
@@ -1045,6 +1085,12 @@ class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("featureImportances")
+ @property
+ @since("2.0.0")
+ def trees(self):
+ """Trees in this ensemble. Warning: These have null parent Estimators."""
+ return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,