author     sethah <seth.hendrickson16@gmail.com>        2016-03-09 14:44:51 -0800
committer  Joseph K. Bradley <joseph@databricks.com>    2016-03-09 14:44:51 -0800
commit     e1772d3f19bed7e69a80de7900ed22d3eeb05300
tree       9db2d2a2b3ac0786141cc51790dc4de0f8e307c5 /mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
parent     c6aa356cd831ea2d159568b699bd5b791f3d8f25
[SPARK-11861][ML] Add feature importances for decision trees
This patch adds an API entry point for single decision tree feature importances.
Author: sethah <seth.hendrickson16@gmail.com>
Closes #9912 from sethah/SPARK-11861.
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala  19
1 file changed, 19 insertions(+), 0 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 8c4cec1326..7f0397f6bd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -169,6 +169,25 @@ final class DecisionTreeClassificationModel private[ml] (
     s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes"
   }
 
+  /**
+   * Estimate of the importance of each feature.
+   *
+   * This generalizes the idea of "Gini" importance to other losses,
+   * following the explanation of Gini importance from "Random Forests" documentation
+   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+   *
+   * This feature importance is calculated as follows:
+   *  - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+   *    where gain is scaled by the number of instances passing through node
+   *  - Normalize importances for tree to sum to 1.
+   *
+   * Note: Feature importance for single decision trees can have high variance due to
+   *       correlated predictor variables. Consider using a [[RandomForestClassifier]]
+   *       to determine feature importance instead.
+   */
+  @Since("2.0.0")
+  lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldDecisionTreeModel = {
     new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
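The calculation described in the Javadoc above (whose real implementation lives in `RandomForest.featureImportances`) can be sketched in a few lines against a simplified tree representation. The `Node`/`Split`/`Leaf` types and the `FeatureImportance` object below are illustrative stand-ins, not Spark's actual internal classes: each split records its impurity gain and the number of training instances that reached it, gains are accumulated per feature, and the result is normalized to sum to 1.

```scala
// Minimal sketch of per-tree feature importance, assuming a simplified
// tree model (NOT Spark's internal Node classes).
sealed trait Node
case class Leaf(prediction: Double) extends Node
case class Split(featureIndex: Int,   // feature this node splits on
                 gain: Double,        // impurity gain of the split
                 count: Long,         // instances passing through this node
                 left: Node,
                 right: Node) extends Node

object FeatureImportance {
  /** importance(j) = sum over splits on feature j of (gain * count), normalized to sum to 1. */
  def featureImportances(root: Node, numFeatures: Int): Array[Double] = {
    val imp = new Array[Double](numFeatures)
    def visit(node: Node): Unit = node match {
      case Split(j, gain, count, l, r) =>
        imp(j) += gain * count  // scale gain by instances reaching the node
        visit(l); visit(r)
      case Leaf(_) => ()
    }
    visit(root)
    val total = imp.sum
    if (total > 0) imp.map(_ / total) else imp
  }
}
```

For example, a tree whose root splits on feature 0 (gain 0.5 over 100 instances) and whose left child splits on feature 1 (gain 0.2 over 60 instances) accumulates raw scores 50 and 12, which normalize to 50/62 and 12/62. The normalization step is why, as the Javadoc warns, a single tree's importances can swing sharply when correlated predictors trade places at the top splits.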