path: root/python/pyspark/ml/regression.py
author     sethah <seth.hendrickson16@gmail.com>    2016-03-31 13:00:10 -0700
committer  Joseph K. Bradley <joseph@databricks.com>    2016-03-31 13:00:10 -0700
commit     b11887c086974dbab18b9f53e99a26bbe06e9c86 (patch)
tree       e7a2b3254ae69d4b1605e2b0d5c7071d75c058cc /python/pyspark/ml/regression.py
parent     e785402826dcd984d9312470464714ba6c908a49 (diff)
[SPARK-14264][PYSPARK][ML] Add feature importance for GBTs in pyspark
## What changes were proposed in this pull request?

Feature importances are exposed in the Python API for GBTs.

Other changes:
* Update the random forest feature importance documentation to not repeat the decision tree docstring and instead place a reference to it.

## How was this patch tested?

Python doc tests were updated to validate GBT feature importance.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #12056 from sethah/Pyspark_GBT_feature_importance.
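For context, a minimal end-to-end sketch of what the patch enables, modelled on the updated doctest below. The SparkSession setup and the Vectors import path are assumptions for illustration (older PySpark releases exposed Vectors under pyspark.mllib.linalg), not part of the patch itself:

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import GBTRegressor

# Toy one-feature dataset, mirroring the doctest data in the diff.
spark = SparkSession.builder.master("local[2]").appName("gbt-importance").getOrCreate()
df = spark.createDataFrame([
    (1.0, Vectors.dense(1.0)),
    (0.0, Vectors.sparse(1, [], []))], ["label", "features"])

gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
model = gbt.fit(df)

# New in this patch: the fitted GBT model exposes featureImportances.
print(model.featureImportances)   # e.g. SparseVector(1, {0: 1.0}) for this toy data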
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--  python/pyspark/ml/regression.py  33
1 file changed, 23 insertions(+), 10 deletions(-)
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 37648549de..de8a5e4bed 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -533,7 +533,7 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada
- Normalize importances for tree to sum to 1.
Note: Feature importance for single decision trees can have high variance due to
- correlated predictor variables. Consider using a :class:`RandomForestRegressor`
+ correlated predictor variables. Consider using a :py:class:`RandomForestRegressor`
to determine feature importance instead.
"""
return self._call_java("featureImportances")
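The note above steers users toward an ensemble when importance estimates matter. A short hypothetical sketch of that advice, reusing the SparkSession and toy DataFrame `df` from the sketch near the top of this page (numTrees and maxDepth are illustrative values, not from the patch):

from pyspark.ml.regression import RandomForestRegressor

# An ensemble gives lower-variance importance estimates than a single tree.
rf = RandomForestRegressor(numTrees=20, maxDepth=5, seed=42)
rf_model = rf.fit(df)
print(rf_model.featureImportances)   # averaged over 20 trees, sums to 1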
@@ -626,16 +626,12 @@ class RandomForestRegressionModel(TreeEnsembleModels):
"""
Estimate of the importance of each feature.
- This generalizes the idea of "Gini" importance to other losses,
- following the explanation of Gini importance from "Random Forests" documentation
- by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ Each feature's importance is the average of its importance across all trees in the ensemble.
+ The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ and follows the implementation from scikit-learn.
- This feature importance is calculated as follows:
- - Average over trees:
- - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- where gain is scaled by the number of instances passing through node
- - Normalize importances for tree to sum to 1.
- - Normalize feature importance vector to sum to 1.
+ .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
"""
return self._call_java("featureImportances")
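The arithmetic described by both the old and the new docstring can be written out in plain Python. This is an illustrative sketch of the averaging and normalization, not Spark's implementation:

def ensemble_importances(per_tree_importances):
    """Average per-tree importance vectors (each already normalized to sum
    to 1 within its tree), then renormalize the average to sum to 1."""
    num_trees = len(per_tree_importances)
    num_features = len(per_tree_importances[0])
    avg = [sum(tree[j] for tree in per_tree_importances) / num_trees
           for j in range(num_features)]
    total = sum(avg)
    return [v / total for v in avg]

print(ensemble_importances([[0.7, 0.3], [0.5, 0.5]]))   # [0.6, 0.4]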
@@ -655,6 +651,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
>>> model = gbt.fit(df)
+ >>> model.featureImportances
+ SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -734,6 +732,21 @@ class GBTRegressionModel(TreeEnsembleModels):
.. versionadded:: 1.4.0
"""
+ @property
+ @since("2.0.0")
+ def featureImportances(self):
+ """
+ Estimate of the importance of each feature.
+
+ Each feature's importance is the average of its importance across all trees in the ensemble.
+ The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ and follows the implementation from scikit-learn.
+
+ .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
+ """
+ return self._call_java("featureImportances")
+
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
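
A usage sketch for the new GBTRegressionModel property: pairing the returned vector with feature names. The names list is a hypothetical placeholder supplied by the caller, and `model` is the fitted GBT model from the sketch near the top of this page:

feature_names = ["f0"]                           # hypothetical; one feature in the toy data
importances = model.featureImportances           # SparseVector, entries sum to 1
ranked = sorted(zip(feature_names, importances.toArray()),
                key=lambda kv: kv[1], reverse=True)
for name, score in ranked:
    print("%s: %.3f" % (name, score))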