Diffstat (limited to 'mllib/src/main')
9 files changed, 162 insertions, 132 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 3e4b21bff6..23c4af17f9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -203,7 +203,7 @@ final class DecisionTreeClassificationModel private[ml] (
    * to determine feature importance instead.
    */
   @Since("2.0.0")
-  lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
 
   /** Convert to spark.mllib DecisionTreeModel (losing some information) */
   override private[spark] def toOld: OldDecisionTreeModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index c31df3aa18..48ce051d0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -238,6 +238,19 @@ final class GBTClassificationModel private[ml](
     s"GBTClassificationModel (uid=$uid) with $numTrees trees"
   }
 
+  /**
+   * Estimate of the importance of each feature.
+   *
+   * Each feature's importance is the average of its importance across all trees in the ensemble.
+   * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+   * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+   * and follows the implementation from scikit-learn.
+   *
+   * @see [[DecisionTreeClassificationModel.featureImportances]]
+   */
+  @Since("2.0.0")
+  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldGBTModel = {
     new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 5da04d341d..82fa05a604 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -222,19 +222,15 @@ final class RandomForestClassificationModel private[ml] (
   /**
    * Estimate of the importance of each feature.
    *
-   * This generalizes the idea of "Gini" importance to other losses,
-   * following the explanation of Gini importance from "Random Forests" documentation
-   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+   * Each feature's importance is the average of its importance across all trees in the ensemble.
+   * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+   * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+   * and follows the implementation from scikit-learn.
    *
-   * This feature importance is calculated as follows:
-   *  - Average over trees:
-   *     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-   *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree to sum to 1.
-   *  - Normalize feature importance vector to sum to 1.
+   * @see [[DecisionTreeClassificationModel.featureImportances]]
    */
   @Since("1.5.0")
-  lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
 
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldRandomForestModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 50ac96eb5e..0a3d00e470 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -203,7 +203,7 @@ final class DecisionTreeRegressionModel private[ml] (
    * to determine feature importance instead.
    */
   @Since("2.0.0")
-  lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
 
   /** Convert to spark.mllib DecisionTreeModel (losing some information) */
   override private[spark] def toOld: OldDecisionTreeModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index da5b77e8fa..8fca35da51 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -224,6 +224,19 @@ final class GBTRegressionModel private[ml](
     s"GBTRegressionModel (uid=$uid) with $numTrees trees"
   }
 
+  /**
+   * Estimate of the importance of each feature.
+   *
+   * Each feature's importance is the average of its importance across all trees in the ensemble.
+   * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+   * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+   * and follows the implementation from scikit-learn.
+   *
+   * @see [[DecisionTreeRegressionModel.featureImportances]]
+   */
+  @Since("2.0.0")
+  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldGBTModel = {
     new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 798947b94a..5b3f3a1f5d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -181,19 +181,15 @@ final class RandomForestRegressionModel private[ml] (
   /**
    * Estimate of the importance of each feature.
    *
-   * This generalizes the idea of "Gini" importance to other losses,
-   * following the explanation of Gini importance from "Random Forests" documentation
-   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+   * Each feature's importance is the average of its importance across all trees in the ensemble.
+   * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+   * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+   * and follows the implementation from scikit-learn.
    *
-   * This feature importance is calculated as follows:
-   *  - Average over trees:
-   *     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-   *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree to sum to 1.
-   *  - Normalize feature importance vector to sum to 1.
+   * @see [[DecisionTreeRegressionModel.featureImportances]]
    */
   @Since("1.5.0")
-  lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
 
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldRandomForestModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 1c8a9b4dfe..b37f4e891e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -19,7 +19,9 @@ package org.apache.spark.ml.tree.impl
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.ml.tree.DecisionTreeModel
 import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 7774ae64e5..cccf052b3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -26,7 +26,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.classification.DecisionTreeClassificationModel
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator,
@@ -35,7 +34,6 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.ImpurityStats
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
 
 
@@ -1105,112 +1103,4 @@ private[spark] object RandomForest extends Logging {
     }
   }
 
-  /**
-   * Given a Random Forest model, compute the importance of each feature.
-   * This generalizes the idea of "Gini" importance to other losses,
-   * following the explanation of Gini importance from "Random Forests" documentation
-   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
-   *
-   * This feature importance is calculated as follows:
-   *  - Average over trees:
-   *     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-   *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree to sum to 1.
-   *  - Normalize feature importance vector to sum to 1.
-   *
-   * @param trees Unweighted forest of trees
-   * @param numFeatures Number of features in model (even if not all are explicitly used by
-   *                    the model).
-   *                    If -1, then numFeatures is set based on the max feature index in all trees.
-   * @return Feature importance values, of length numFeatures.
-   */
-  private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
-    val totalImportances = new OpenHashMap[Int, Double]()
-    trees.foreach { tree =>
-      // Aggregate feature importance vector for this tree
-      val importances = new OpenHashMap[Int, Double]()
-      computeFeatureImportance(tree.rootNode, importances)
-      // Normalize importance vector for this tree, and add it to total.
-      // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
-      val treeNorm = importances.map(_._2).sum
-      if (treeNorm != 0) {
-        importances.foreach { case (idx, impt) =>
-          val normImpt = impt / treeNorm
-          totalImportances.changeValue(idx, normImpt, _ + normImpt)
-        }
-      }
-    }
-    // Normalize importances
-    normalizeMapValues(totalImportances)
-    // Construct vector
-    val d = if (numFeatures != -1) {
-      numFeatures
-    } else {
-      // Find max feature index used in trees
-      val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
-      maxFeatureIndex + 1
-    }
-    if (d == 0) {
-      assert(totalImportances.size == 0, s"Unknown error in computing feature" +
-        s" importance: No splits found, but some non-zero importances.")
-    }
-    val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
-    Vectors.sparse(d, indices.toArray, values.toArray)
-  }
-
-  /**
-   * Given a Decision Tree model, compute the importance of each feature.
-   * This generalizes the idea of "Gini" importance to other losses,
-   * following the explanation of Gini importance from "Random Forests" documentation
-   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
-   *
-   * This feature importance is calculated as follows:
-   *  - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-   *    where gain is scaled by the number of instances passing through node
-   *  - Normalize importances for tree to sum to 1.
-   *
-   * @param tree Decision tree to compute importances for.
-   * @param numFeatures Number of features in model (even if not all are explicitly used by
-   *                    the model).
-   *                    If -1, then numFeatures is set based on the max feature index in all trees.
-   * @return Feature importance values, of length numFeatures.
-   */
-  private[ml] def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
-    featureImportances(Array(tree), numFeatures)
-  }
-
-  /**
-   * Recursive method for computing feature importances for one tree.
-   * This walks down the tree, adding to the importance of 1 feature at each node.
-   * @param node Current node in recursion
-   * @param importances Aggregate feature importances, modified by this method
-   */
-  private[impl] def computeFeatureImportance(
-      node: Node,
-      importances: OpenHashMap[Int, Double]): Unit = {
-    node match {
-      case n: InternalNode =>
-        val feature = n.split.featureIndex
-        val scaledGain = n.gain * n.impurityStats.count
-        importances.changeValue(feature, scaledGain, _ + scaledGain)
-        computeFeatureImportance(n.leftChild, importances)
-        computeFeatureImportance(n.rightChild, importances)
-      case n: LeafNode =>
-        // do nothing
-    }
-  }
-
-  /**
-   * Normalize the values of this map to sum to 1, in place.
-   * If all values are 0, this method does nothing.
-   * @param map Map with non-negative values.
-   */
-  private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
-    val total = map.map(_._2).sum
-    if (total != 0) {
-      val keys = map.iterator.map(_._1).toArray
-      keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
-    }
-  }
-
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index ef40c9068f..1fad9d6d8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -27,6 +27,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.collection.OpenHashMap
 
 /**
  * Abstraction for Decision Tree models.
@@ -115,6 +116,125 @@ private[ml] trait TreeEnsembleModel {
   lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
 }
 
+private[ml] object TreeEnsembleModel {
+
+  /**
+   * Given a tree ensemble model, compute the importance of each feature.
+   * This generalizes the idea of "Gini" importance to other losses,
+   * following the explanation of Gini importance from "Random Forests" documentation
+   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+   *
+   * For collections of trees, including boosting and bagging, Hastie et al.
+   * propose to use the average of single tree importances across all trees in the ensemble.
+   *
+   * This feature importance is calculated as follows:
+   *  - Average over trees:
+   *     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+   *       where gain is scaled by the number of instances passing through node
+   *     - Normalize importances for tree to sum to 1.
+   *  - Normalize feature importance vector to sum to 1.
+   *
+   * References:
+   *  - Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.
+   *
+   * @param trees Unweighted collection of trees
+   * @param numFeatures Number of features in model (even if not all are explicitly used by
+   *                    the model).
+   *                    If -1, then numFeatures is set based on the max feature index in all trees.
+   * @return Feature importance values, of length numFeatures.
+   */
+  def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
+    val totalImportances = new OpenHashMap[Int, Double]()
+    trees.foreach { tree =>
+      // Aggregate feature importance vector for this tree
+      val importances = new OpenHashMap[Int, Double]()
+      computeFeatureImportance(tree.rootNode, importances)
+      // Normalize importance vector for this tree, and add it to total.
+      // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
+      val treeNorm = importances.map(_._2).sum
+      if (treeNorm != 0) {
+        importances.foreach { case (idx, impt) =>
+          val normImpt = impt / treeNorm
+          totalImportances.changeValue(idx, normImpt, _ + normImpt)
+        }
+      }
+    }
+    // Normalize importances
+    normalizeMapValues(totalImportances)
+    // Construct vector
+    val d = if (numFeatures != -1) {
+      numFeatures
+    } else {
+      // Find max feature index used in trees
+      val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
+      maxFeatureIndex + 1
+    }
+    if (d == 0) {
+      assert(totalImportances.size == 0, s"Unknown error in computing feature" +
+        s" importance: No splits found, but some non-zero importances.")
+    }
+    val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
+    Vectors.sparse(d, indices.toArray, values.toArray)
+  }
+
+  /**
+   * Given a Decision Tree model, compute the importance of each feature.
+   * This generalizes the idea of "Gini" importance to other losses,
+   * following the explanation of Gini importance from "Random Forests" documentation
+   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+   *
+   * This feature importance is calculated as follows:
+   *  - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+   *    where gain is scaled by the number of instances passing through node
+   *  - Normalize importances for tree to sum to 1.
+   *
+   * @param tree Decision tree to compute importances for.
+   * @param numFeatures Number of features in model (even if not all are explicitly used by
+   *                    the model).
+   *                    If -1, then numFeatures is set based on the max feature index in all trees.
+   * @return Feature importance values, of length numFeatures.
+   */
+  def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
+    featureImportances(Array(tree), numFeatures)
+  }
+
+  /**
+   * Recursive method for computing feature importances for one tree.
+   * This walks down the tree, adding to the importance of 1 feature at each node.
+   *
+   * @param node Current node in recursion
+   * @param importances Aggregate feature importances, modified by this method
+   */
+  def computeFeatureImportance(
+      node: Node,
+      importances: OpenHashMap[Int, Double]): Unit = {
+    node match {
+      case n: InternalNode =>
+        val feature = n.split.featureIndex
+        val scaledGain = n.gain * n.impurityStats.count
+        importances.changeValue(feature, scaledGain, _ + scaledGain)
+        computeFeatureImportance(n.leftChild, importances)
+        computeFeatureImportance(n.rightChild, importances)
+      case n: LeafNode =>
+        // do nothing
+    }
+  }
+
+  /**
+   * Normalize the values of this map to sum to 1, in place.
+   * If all values are 0, this method does nothing.
+   *
+   * @param map Map with non-negative values.
+   */
+  def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
+    val total = map.map(_._2).sum
+    if (total != 0) {
+      val keys = map.iterator.map(_._1).toArray
+      keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
+    }
+  }
+}
+
 /** Helper classes for tree model persistence */
 private[ml] object DecisionTreeModelReadWrite {
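
For readers following the consolidated TreeEnsembleModel.featureImportances logic, the algorithm can be summarized outside of Spark. The sketch below is a minimal, self-contained Scala rendering of the same computation (accumulate gain scaled by instance count per split feature, normalize per tree, sum across trees, normalize the total). The SimpleNode/Internal/Leaf types are simplified stand-ins invented for illustration, not Spark's Node classes, and the sketch omits Spark's OpenHashMap and sparse-vector output.

import scala.collection.mutable

// Simplified stand-in types for illustration only -- not Spark's node classes.
sealed trait SimpleNode
case object Leaf extends SimpleNode
final case class Internal(
    featureIndex: Int,  // feature this node splits on
    gain: Double,       // impurity gain of the split
    count: Long,        // number of training instances passing through this node
    left: SimpleNode,
    right: SimpleNode) extends SimpleNode

object FeatureImportanceSketch {

  // Walk one tree, accumulating gain scaled by instance count for each split feature.
  private def accumulate(node: SimpleNode, acc: mutable.Map[Int, Double]): Unit = node match {
    case Internal(feature, gain, count, left, right) =>
      acc(feature) = acc.getOrElse(feature, 0.0) + gain * count
      accumulate(left, acc)
      accumulate(right, acc)
    case Leaf => // leaves contribute nothing
  }

  // Normalize the map's values to sum to 1; a no-op if all values are 0.
  private def normalized(m: Map[Int, Double]): Map[Int, Double] = {
    val total = m.values.sum
    if (total == 0.0) m else m.map { case (k, v) => k -> v / total }
  }

  // Normalize each tree's importances, sum across trees, then normalize the total.
  def featureImportances(trees: Seq[SimpleNode]): Map[Int, Double] = {
    val perTree: Seq[Map[Int, Double]] = trees.map { root =>
      val acc = mutable.Map.empty[Int, Double]
      accumulate(root, acc)
      normalized(acc.toMap)
    }
    val summed = perTree.flatten.groupBy(_._1).map {
      case (feature, contribs) => feature -> contribs.map(_._2).sum
    }
    normalized(summed)
  }
}

Summing the normalized per-tree vectors and renormalizing gives the same result as averaging them, since the final normalization absorbs the division by the number of trees.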
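On the API side, the practical effect of this change is that gradient-boosted tree models expose the same featureImportances member that random forest models have had since 1.5.0. A minimal usage sketch, assuming an existing DataFrame named training with the usual "label" and "features" columns:

import org.apache.spark.ml.classification.GBTClassifier

// `training` is an assumed DataFrame with "label" and "features" columns.
val model = new GBTClassifier()
  .setMaxIter(10)
  .fit(training)

// New in this change for GBT models (@Since 2.0.0): a Vector of per-feature
// importances, normalized to sum to 1.
println(model.featureImportances)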