author     sethah <seth.hendrickson16@gmail.com>       2016-03-09 14:44:51 -0800
committer  Joseph K. Bradley <joseph@databricks.com>   2016-03-09 14:44:51 -0800
commit     e1772d3f19bed7e69a80de7900ed22d3eeb05300 (patch)
tree       9db2d2a2b3ac0786141cc51790dc4de0f8e307c5 /mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
parent     c6aa356cd831ea2d159568b699bd5b791f3d8f25 (diff)
[SPARK-11861][ML] Add feature importances for decision trees
This patch adds an API entry point for single decision tree feature importances.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #9912 from sethah/SPARK-11861.
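For illustration only (not part of the patch), a minimal usage sketch of the new API. It assumes a DataFrame named `training` with the usual "label" and "features" columns; at this commit, spark.ml still returned vectors as org.apache.spark.mllib.linalg.Vector:

    import org.apache.spark.ml.classification.DecisionTreeClassifier
    import org.apache.spark.mllib.linalg.Vector

    // Fit a single decision tree on a DataFrame with "label" and "features" columns.
    val dt = new DecisionTreeClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
    val model = dt.fit(training)

    // New in this patch: per-feature importances, normalized to sum to 1.
    val importances: Vector = model.featureImportances
    println(importances)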
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | 19
1 file changed, 19 insertions(+), 0 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 8c4cec1326..7f0397f6bd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -169,6 +169,25 @@ final class DecisionTreeClassificationModel private[ml] (
s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ *
+ * Note: Feature importance for single decision trees can have high variance due to
+ * correlated predictor variables. Consider using a [[RandomForestClassifier]]
+ * to determine feature importance instead.
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
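The doc comment in the patch describes the aggregation only informally. A minimal standalone sketch of the same idea, using simplified stand-in types rather than Spark's internal Node classes or the actual RandomForest.featureImportances implementation:

    // Simplified stand-ins for Spark's internal tree nodes (assumptions, not Spark API).
    sealed trait Tree
    case class Leaf(count: Long) extends Tree
    case class SplitNode(featureIndex: Int, gain: Double, count: Long,
                         left: Tree, right: Tree) extends Tree

    def featureImportancesSketch(root: Tree, numFeatures: Int): Array[Double] = {
      val imp = new Array[Double](numFeatures)
      def visit(t: Tree): Unit = t match {
        case SplitNode(f, gain, count, left, right) =>
          // importance(f) accumulates the gain, scaled by the number of
          // instances passing through the node.
          imp(f) += gain * count
          visit(left); visit(right)
        case Leaf(_) => ()
      }
      visit(root)
      val total = imp.sum
      if (total > 0) imp.map(_ / total) else imp  // normalize importances to sum to 1
    }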