aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2016-03-28 22:27:53 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-28 22:27:53 -0700
commitf6066b0c3c35ceea1706378145e15776c9b4415a (patch)
tree3878e899b9bbf1b568c82a2ff0783204a7e81b3e /mllib/src/main
parentd3638d7bffd4ee43db594c0669d86fb64d448fc8 (diff)
downloadspark-f6066b0c3c35ceea1706378145e15776c9b4415a.tar.gz
spark-f6066b0c3c35ceea1706378145e15776c9b4415a.tar.bz2
spark-f6066b0c3c35ceea1706378145e15776c9b4415a.zip
[SPARK-11730][ML] Add feature importances for GBTs.
## What changes were proposed in this pull request? Now that GBTs have been moved to ML, they can use the implementation of feature importance for random forests. This patch simply adds a `featureImportances` attribute to `GBTClassifier` and `GBTRegressor` and adds tests for each. GBT feature importances here simply average the feature importances for each tree in its ensemble. This follows the implementation from scikit-learn. This method is also suggested by J Friedman in [this paper](https://statweb.stanford.edu/~jhf/ftp/trebst.pdf). ## How was this patch tested? Unit tests were added to `GBTClassifierSuite` and `GBTRegressorSuite` to validate feature importances. Author: sethah <seth.hendrickson16@gmail.com> Closes #11961 from sethah/SPARK-11730.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala13
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala16
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala13
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala16
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala110
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala120
9 files changed, 162 insertions, 132 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 3e4b21bff6..23c4af17f9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -203,7 +203,7 @@ final class DecisionTreeClassificationModel private[ml] (
* to determine feature importance instead.
*/
@Since("2.0.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
/** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
override private[spark] def toOld: OldDecisionTreeModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index c31df3aa18..48ce051d0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -238,6 +238,19 @@ final class GBTClassificationModel private[ml](
s"GBTClassificationModel (uid=$uid) with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * Each feature's importance is the average of its importance across all trees in the ensemble
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
+ *
+ * @see [[DecisionTreeClassificationModel.featureImportances]]
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 5da04d341d..82fa05a604 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -222,19 +222,15 @@ final class RandomForestClassificationModel private[ml] (
/**
* Estimate of the importance of each feature.
*
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ * Each feature's importance is the average of its importance across all trees in the ensemble
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
*
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
+ * @see [[DecisionTreeClassificationModel.featureImportances]]
*/
@Since("1.5.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 50ac96eb5e..0a3d00e470 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -203,7 +203,7 @@ final class DecisionTreeRegressionModel private[ml] (
* to determine feature importance instead.
*/
@Since("2.0.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
/** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
override private[spark] def toOld: OldDecisionTreeModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index da5b77e8fa..8fca35da51 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -224,6 +224,19 @@ final class GBTRegressionModel private[ml](
s"GBTRegressionModel (uid=$uid) with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * Each feature's importance is the average of its importance across all trees in the ensemble
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
+ *
+ * @see [[DecisionTreeRegressionModel.featureImportances]]
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 798947b94a..5b3f3a1f5d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -181,19 +181,15 @@ final class RandomForestRegressionModel private[ml] (
/**
* Estimate of the importance of each feature.
*
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ * Each feature's importance is the average of its importance across all trees in the ensemble
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
*
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
+ * @see [[DecisionTreeRegressionModel.featureImportances]]
*/
@Since("1.5.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 1c8a9b4dfe..b37f4e891e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -19,7 +19,9 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.internal.Logging
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.ml.tree.DecisionTreeModel
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 7774ae64e5..cccf052b3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -26,7 +26,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator,
@@ -35,7 +34,6 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
@@ -1105,112 +1103,4 @@ private[spark] object RandomForest extends Logging {
}
}
- /**
- * Given a Random Forest model, compute the importance of each feature.
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
- *
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
- *
- * @param trees Unweighted forest of trees
- * @param numFeatures Number of features in model (even if not all are explicitly used by
- * the model).
- * If -1, then numFeatures is set based on the max feature index in all trees.
- * @return Feature importance values, of length numFeatures.
- */
- private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
- val totalImportances = new OpenHashMap[Int, Double]()
- trees.foreach { tree =>
- // Aggregate feature importance vector for this tree
- val importances = new OpenHashMap[Int, Double]()
- computeFeatureImportance(tree.rootNode, importances)
- // Normalize importance vector for this tree, and add it to total.
- // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
- val treeNorm = importances.map(_._2).sum
- if (treeNorm != 0) {
- importances.foreach { case (idx, impt) =>
- val normImpt = impt / treeNorm
- totalImportances.changeValue(idx, normImpt, _ + normImpt)
- }
- }
- }
- // Normalize importances
- normalizeMapValues(totalImportances)
- // Construct vector
- val d = if (numFeatures != -1) {
- numFeatures
- } else {
- // Find max feature index used in trees
- val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
- maxFeatureIndex + 1
- }
- if (d == 0) {
- assert(totalImportances.size == 0, s"Unknown error in computing feature" +
- s" importance: No splits found, but some non-zero importances.")
- }
- val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
- Vectors.sparse(d, indices.toArray, values.toArray)
- }
-
- /**
- * Given a Decision Tree model, compute the importance of each feature.
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
- *
- * This feature importance is calculated as follows:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- *
- * @param tree Decision tree to compute importances for.
- * @param numFeatures Number of features in model (even if not all are explicitly used by
- * the model).
- * If -1, then numFeatures is set based on the max feature index in all trees.
- * @return Feature importance values, of length numFeatures.
- */
- private[ml] def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
- featureImportances(Array(tree), numFeatures)
- }
-
- /**
- * Recursive method for computing feature importances for one tree.
- * This walks down the tree, adding to the importance of 1 feature at each node.
- * @param node Current node in recursion
- * @param importances Aggregate feature importances, modified by this method
- */
- private[impl] def computeFeatureImportance(
- node: Node,
- importances: OpenHashMap[Int, Double]): Unit = {
- node match {
- case n: InternalNode =>
- val feature = n.split.featureIndex
- val scaledGain = n.gain * n.impurityStats.count
- importances.changeValue(feature, scaledGain, _ + scaledGain)
- computeFeatureImportance(n.leftChild, importances)
- computeFeatureImportance(n.rightChild, importances)
- case n: LeafNode =>
- // do nothing
- }
- }
-
- /**
- * Normalize the values of this map to sum to 1, in place.
- * If all values are 0, this method does nothing.
- * @param map Map with non-negative values.
- */
- private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
- val total = map.map(_._2).sum
- if (total != 0) {
- val keys = map.iterator.map(_._1).toArray
- keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
- }
- }
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index ef40c9068f..1fad9d6d8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -27,6 +27,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.collection.OpenHashMap
/**
* Abstraction for Decision Tree models.
@@ -115,6 +116,125 @@ private[ml] trait TreeEnsembleModel {
lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
}
+private[ml] object TreeEnsembleModel {
+
+ /**
+ * Given a tree ensemble model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * For collections of trees, including boosting and bagging, Hastie et al.
+ * propose to use the average of single tree importances across all trees in the ensemble.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ * - Normalize feature importance vector to sum to 1.
+ *
+ * References:
+ * - Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.
+ *
+ * @param trees Unweighted collection of trees
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @return Feature importance values, of length numFeatures.
+ */
+ def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
+ val totalImportances = new OpenHashMap[Int, Double]()
+ trees.foreach { tree =>
+ // Aggregate feature importance vector for this tree
+ val importances = new OpenHashMap[Int, Double]()
+ computeFeatureImportance(tree.rootNode, importances)
+ // Normalize importance vector for this tree, and add it to total.
+ // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
+ val treeNorm = importances.map(_._2).sum
+ if (treeNorm != 0) {
+ importances.foreach { case (idx, impt) =>
+ val normImpt = impt / treeNorm
+ totalImportances.changeValue(idx, normImpt, _ + normImpt)
+ }
+ }
+ }
+ // Normalize importances
+ normalizeMapValues(totalImportances)
+ // Construct vector
+ val d = if (numFeatures != -1) {
+ numFeatures
+ } else {
+ // Find max feature index used in trees
+ val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
+ maxFeatureIndex + 1
+ }
+ if (d == 0) {
+ assert(totalImportances.size == 0, s"Unknown error in computing feature" +
+ s" importance: No splits found, but some non-zero importances.")
+ }
+ val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
+ Vectors.sparse(d, indices.toArray, values.toArray)
+ }
+
+ /**
+ * Given a Decision Tree model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ *
+ * @param tree Decision tree to compute importances for.
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @return Feature importance values, of length numFeatures.
+ */
+ def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
+ featureImportances(Array(tree), numFeatures)
+ }
+
+ /**
+ * Recursive method for computing feature importances for one tree.
+ * This walks down the tree, adding to the importance of 1 feature at each node.
+ *
+ * @param node Current node in recursion
+ * @param importances Aggregate feature importances, modified by this method
+ */
+ def computeFeatureImportance(
+ node: Node,
+ importances: OpenHashMap[Int, Double]): Unit = {
+ node match {
+ case n: InternalNode =>
+ val feature = n.split.featureIndex
+ val scaledGain = n.gain * n.impurityStats.count
+ importances.changeValue(feature, scaledGain, _ + scaledGain)
+ computeFeatureImportance(n.leftChild, importances)
+ computeFeatureImportance(n.rightChild, importances)
+ case n: LeafNode =>
+ // do nothing
+ }
+ }
+
+ /**
+ * Normalize the values of this map to sum to 1, in place.
+ * If all values are 0, this method does nothing.
+ *
+ * @param map Map with non-negative values.
+ */
+ def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
+ val total = map.map(_._2).sum
+ if (total != 0) {
+ val keys = map.iterator.map(_._1).toArray
+ keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
+ }
+ }
+}
+
/** Helper classes for tree model persistence */
private[ml] object DecisionTreeModelReadWrite {