author    Joseph K. Bradley <joseph@databricks.com>    2015-08-03 12:17:46 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2015-08-03 12:17:46 -0700
commit    ff9169a002f1b75231fd25b7d04157a912503038 (patch)
tree      ef57aa63ad02760806657e491a78f15f5daa7f66 /mllib/src/main
parent    703e44bff19f4c394f6f9bff1ce9152cdc68c51e (diff)
[SPARK-5133] [ML] Added featureImportance to RandomForestClassifier and Regressor
Added featureImportance to RandomForestClassifier and Regressor. This follows the scikit-learn implementation here:
https://github.com/scikit-learn/scikit-learn/blob/a95203b249c1cf392f86d001ad999e29b2392739/sklearn/tree/_tree.pyx#L3341

CC: yanboliang Would you mind taking a look? Thanks!

Author: Joseph K. Bradley <joseph@databricks.com>
Author: Feynman Liang <fliang@databricks.com>

Closes #7838 from jkbradley/dt-feature-importance and squashes the following commits:

72a167a [Joseph K. Bradley] fixed unit test
86cea5f [Joseph K. Bradley] Modified RF featuresImportances to return Vector instead of Map
5aa74f0 [Joseph K. Bradley] finally fixed unit test for real
33df5db [Joseph K. Bradley] fix unit test
42a2d3b [Joseph K. Bradley] fix unit test
fe94e72 [Joseph K. Bradley] modified feature importance unit tests
cc693ee [Feynman Liang] Add classifier tests
79a6f87 [Feynman Liang] Compare dense vectors in test
21d01fc [Feynman Liang] Added failing SKLearn test
ac0b254 [Joseph K. Bradley] Added featureImportance to RandomForestClassifier/Regressor. Need to add unit tests
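A minimal usage sketch of the new API, assuming a training DataFrame `train` with the usual "label" and "features" columns (the data and parameter values here are illustrative, not part of this patch):

import org.apache.spark.ml.classification.RandomForestClassifier

// Fit a forest, then read the new importance vector off the model.
val rf = new RandomForestClassifier()
  .setNumTrees(20)
  .setFeatureSubsetStrategy("auto")
val model = rf.fit(train)

// featureImportances is a Vector of length numFeatures that sums to 1.
println(model.featureImportances)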
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala  30
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala       33
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala                              19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala                 92
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala                         6
5 files changed, 166 insertions, 14 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 56e80cc8fe..b59826a594 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -95,7 +95,8 @@ final class RandomForestClassifier(override val uid: String)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeClassificationModel])
- new RandomForestClassificationModel(trees, numClasses)
+ val numFeatures = oldDataset.first().features.size
+ new RandomForestClassificationModel(trees, numFeatures, numClasses)
}
override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
@@ -118,11 +119,13 @@ object RandomForestClassifier {
* features.
* @param _trees Decision trees in the ensemble.
* Warning: These have null parents.
+ * @param numFeatures Number of features used by this model
*/
@Experimental
final class RandomForestClassificationModel private[ml] (
override val uid: String,
private val _trees: Array[DecisionTreeClassificationModel],
+ val numFeatures: Int,
override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
@@ -133,8 +136,8 @@ final class RandomForestClassificationModel private[ml] (
* Construct a random forest classification model, with all trees weighted equally.
* @param trees Component trees
*/
- def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) =
- this(Identifiable.randomUID("rfc"), trees, numClasses)
+ def this(trees: Array[DecisionTreeClassificationModel], numFeatures: Int, numClasses: Int) =
+ this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
@@ -182,13 +185,30 @@ final class RandomForestClassificationModel private[ml] (
}
override def copy(extra: ParamMap): RandomForestClassificationModel = {
- copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
+ copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
}
override def toString: String = {
s"RandomForestClassificationModel with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree based on total number of training instances used
+ * to build tree.
+ * - Normalize feature importance vector to sum to 1.
+ */
+ lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
@@ -210,6 +230,6 @@ private[ml] object RandomForestClassificationModel {
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
- new RandomForestClassificationModel(uid, newTrees, numClasses)
+ new RandomForestClassificationModel(uid, newTrees, -1, numClasses)
}
}
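With a fitted classification model in hand, pairing the importance vector with feature names is straightforward; `featureNames` below is a hypothetical caller-supplied array for illustration, since the model itself only knows feature indices:

// Hypothetical: names come from the caller, not from the model.
val featureNames: Array[String] = Array("age", "income", "clicks")
featureNames.zip(model.featureImportances.toArray)
  .sortBy(-_._2)
  .foreach { case (name, imp) => println(f"$name%-10s $imp%.4f") }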
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 17fb1ad5e1..1ee43c8725 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
+
/**
* :: Experimental ::
@@ -87,7 +87,8 @@ final class RandomForestRegressor(override val uid: String)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeRegressionModel])
- new RandomForestRegressionModel(trees)
+ val numFeatures = oldDataset.first().features.size
+ new RandomForestRegressionModel(trees, numFeatures)
}
override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
@@ -108,11 +109,13 @@ object RandomForestRegressor {
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
* It supports both continuous and categorical features.
* @param _trees Decision trees in the ensemble.
+ * @param numFeatures Number of features used by this model
*/
@Experimental
final class RandomForestRegressionModel private[ml] (
override val uid: String,
- private val _trees: Array[DecisionTreeRegressionModel])
+ private val _trees: Array[DecisionTreeRegressionModel],
+ val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
with TreeEnsembleModel with Serializable {
@@ -122,7 +125,8 @@ final class RandomForestRegressionModel private[ml] (
* Construct a random forest regression model, with all trees weighted equally.
* @param trees Component trees
*/
- def this(trees: Array[DecisionTreeRegressionModel]) = this(Identifiable.randomUID("rfr"), trees)
+ def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
+ this(Identifiable.randomUID("rfr"), trees, numFeatures)
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
@@ -147,13 +151,30 @@ final class RandomForestRegressionModel private[ml] (
}
override def copy(extra: ParamMap): RandomForestRegressionModel = {
- copyValues(new RandomForestRegressionModel(uid, _trees), extra)
+ copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra)
}
override def toString: String = {
s"RandomForestRegressionModel with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree based on total number of training instances used
+ * to build tree.
+ * - Normalize feature importance vector to sum to 1.
+ */
+ lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
@@ -173,6 +194,6 @@ private[ml] object RandomForestRegressionModel {
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestRegressionModel(parent.uid, newTrees)
+ new RandomForestRegressionModel(parent.uid, newTrees, -1)
}
}
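The regressor exposes the same importances; a sketch mirroring the classifier example earlier, again assuming an illustrative DataFrame `train`:

import org.apache.spark.ml.regression.RandomForestRegressor

val regModel = new RandomForestRegressor().setNumTrees(20).fit(train)
println(regModel.featureImportances)  // also normalized to sum to 1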
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index 8879352a60..cd24931293 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -44,7 +44,7 @@ sealed abstract class Node extends Serializable {
* and probabilities.
* For classification, the array of class counts must be normalized to a probability distribution.
*/
- private[tree] def impurityStats: ImpurityCalculator
+ private[ml] def impurityStats: ImpurityCalculator
/** Recursive prediction helper method */
private[ml] def predictImpl(features: Vector): LeafNode
@@ -72,6 +72,12 @@ sealed abstract class Node extends Serializable {
* @param id Node ID using old format IDs
*/
private[ml] def toOld(id: Int): OldNode
+
+ /**
+ * Trace down the tree, and return the largest feature index used in any split.
+ * @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
+ */
+ private[ml] def maxSplitFeatureIndex(): Int
}
private[ml] object Node {
@@ -109,7 +115,7 @@ private[ml] object Node {
final class LeafNode private[ml] (
override val prediction: Double,
override val impurity: Double,
- override val impurityStats: ImpurityCalculator) extends Node {
+ override private[ml] val impurityStats: ImpurityCalculator) extends Node {
override def toString: String =
s"LeafNode(prediction = $prediction, impurity = $impurity)"
@@ -129,6 +135,8 @@ final class LeafNode private[ml] (
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
impurity, isLeaf = true, None, None, None, None)
}
+
+ override private[ml] def maxSplitFeatureIndex(): Int = -1
}
/**
@@ -150,7 +158,7 @@ final class InternalNode private[ml] (
val leftChild: Node,
val rightChild: Node,
val split: Split,
- override val impurityStats: ImpurityCalculator) extends Node {
+ override private[ml] val impurityStats: ImpurityCalculator) extends Node {
override def toString: String = {
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
@@ -190,6 +198,11 @@ final class InternalNode private[ml] (
new OldPredict(leftChild.prediction, prob = 0.0),
new OldPredict(rightChild.prediction, prob = 0.0))))
}
+
+ override private[ml] def maxSplitFeatureIndex(): Int = {
+ math.max(split.featureIndex,
+ math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex()))
+ }
}
private object InternalNode {
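The recursion added to Node.scala above is small; a self-contained sketch of the same idea over a toy tree type (the names here are illustrative, not Spark's):

// Toy tree mirroring the recursion: leaves report -1, internal nodes take the
// max of their own split feature index and both subtrees.
sealed trait ToyNode
case object ToyLeaf extends ToyNode
case class ToyInternal(featureIndex: Int, left: ToyNode, right: ToyNode) extends ToyNode

def maxSplitFeatureIndex(node: ToyNode): Int = node match {
  case ToyLeaf => -1
  case ToyInternal(f, l, r) =>
    math.max(f, math.max(maxSplitFeatureIndex(l), maxSplitFeatureIndex(r)))
}

// A single leaf yields -1, matching the sentinel used when numFeatures == -1.
assert(maxSplitFeatureIndex(ToyLeaf) == -1)
assert(maxSplitFeatureIndex(ToyInternal(3, ToyLeaf, ToyInternal(7, ToyLeaf, ToyLeaf))) == 7)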
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index a8b90d9d26..4ac51a4754 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -26,6 +26,7 @@ import org.apache.spark.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
@@ -34,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
@@ -1113,4 +1115,94 @@ private[ml] object RandomForest extends Logging {
}
}
+ /**
+ * Given a Random Forest model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree based on total number of training instances used
+ * to build tree.
+ * - Normalize feature importance vector to sum to 1.
+ *
+ * Note: This should not be used with Gradient-Boosted Trees. It only makes sense for
+ * independently trained trees.
+ * @param trees Unweighted forest of trees
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @return Feature importance values, of length numFeatures.
+ */
+ private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
+ val totalImportances = new OpenHashMap[Int, Double]()
+ trees.foreach { tree =>
+ // Aggregate feature importance vector for this tree
+ val importances = new OpenHashMap[Int, Double]()
+ computeFeatureImportance(tree.rootNode, importances)
+ // Normalize importance vector for this tree, and add it to total.
+ // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
+ val treeNorm = importances.map(_._2).sum
+ if (treeNorm != 0) {
+ importances.foreach { case (idx, impt) =>
+ val normImpt = impt / treeNorm
+ totalImportances.changeValue(idx, normImpt, _ + normImpt)
+ }
+ }
+ }
+ // Normalize importances
+ normalizeMapValues(totalImportances)
+ // Construct vector
+ val d = if (numFeatures != -1) {
+ numFeatures
+ } else {
+ // Find max feature index used in trees
+ val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
+ maxFeatureIndex + 1
+ }
+ if (d == 0) {
+ assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" +
+ s" importance: No splits in forest, but some non-zero importances.")
+ }
+ val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
+ Vectors.sparse(d, indices.toArray, values.toArray)
+ }
+
+ /**
+ * Recursive method for computing feature importances for one tree.
+ * This walks down the tree, adding to the importance of 1 feature at each node.
+ * @param node Current node in recursion
+ * @param importances Aggregate feature importances, modified by this method
+ */
+ private[impl] def computeFeatureImportance(
+ node: Node,
+ importances: OpenHashMap[Int, Double]): Unit = {
+ node match {
+ case n: InternalNode =>
+ val feature = n.split.featureIndex
+ val scaledGain = n.gain * n.impurityStats.count
+ importances.changeValue(feature, scaledGain, _ + scaledGain)
+ computeFeatureImportance(n.leftChild, importances)
+ computeFeatureImportance(n.rightChild, importances)
+ case n: LeafNode =>
+ // do nothing
+ }
+ }
+
+ /**
+ * Normalize the values of this map to sum to 1, in place.
+ * If all values are 0, this method does nothing.
+ * @param map Map with non-negative values.
+ */
+ private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
+ val total = map.map(_._2).sum
+ if (total != 0) {
+ val keys = map.iterator.map(_._1).toArray
+ keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
+ }
+ }
+
}
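To make the arithmetic in RandomForest.featureImportances above concrete, here is a self-contained sketch using plain Scala maps in place of OpenHashMap; the per-tree gain totals are made up for illustration:

import scala.collection.mutable

// Each inner map is one tree's raw importances: featureIndex -> sum of
// (gain * instance count) over that tree's splits on the feature.
val perTreeGains = Seq(
  mutable.Map(0 -> 30.0, 2 -> 10.0),  // tree 1
  mutable.Map(0 -> 5.0, 1 -> 15.0))   // tree 2

val total = mutable.Map.empty[Int, Double].withDefaultValue(0.0)
perTreeGains.foreach { imps =>
  val treeNorm = imps.values.sum                  // normalize per tree first
  if (treeNorm != 0) {
    imps.foreach { case (idx, g) => total(idx) += g / treeNorm }
  }
}
val grand = total.values.sum                      // then normalize the total
val importances = total.map { case (idx, v) => idx -> v / grand }

// Tree 1 contributes (0.75, 0, 0.25); tree 2 contributes (0.25, 0.75, 0);
// after global normalization: feature 0 -> 0.5, 1 -> 0.375, 2 -> 0.125.
println(importances.toSeq.sortBy(_._1))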
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 22873909c3..b77191156f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -53,6 +53,12 @@ private[ml] trait DecisionTreeModel {
val header = toString + "\n"
header + rootNode.subtreeToString(2)
}
+
+ /**
+ * Trace down the tree, and return the largest feature index used in any split.
+ * @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
+ */
+ private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex()
}
/**