Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala      2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala              13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala     16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala           2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala                   13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala          16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala             2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala                   110
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala                          120
9 files changed, 162 insertions(+), 132 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 3e4b21bff6..23c4af17f9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -203,7 +203,7 @@ final class DecisionTreeClassificationModel private[ml] (
* to determine feature importance instead.
*/
@Since("2.0.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
/** Convert to spark.mllib DecisionTreeModel (losing some information) */
override private[spark] def toOld: OldDecisionTreeModel = {
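The change above only redirects the delegation target: the public `featureImportances` member keeps its signature and `@Since` tag, but now calls the helper on `TreeEnsembleModel` rather than on the `RandomForest` training object. A minimal usage sketch, unchanged in behavior by this patch; `training` is a hypothetical DataFrame with the usual "label" and "features" columns:

    import org.apache.spark.ml.classification.DecisionTreeClassifier

    // Fit a single tree; `training` is assumed to exist.
    val dtModel = new DecisionTreeClassifier()
      .setMaxDepth(5)
      .fit(training)

    // Lazily computed on first access; after this patch it delegates to
    // TreeEnsembleModel.featureImportances(this, numFeatures).
    println(dtModel.featureImportances)  // e.g. (4,[0,2],[0.7,0.3])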
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index c31df3aa18..48ce051d0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -238,6 +238,19 @@ final class GBTClassificationModel private[ml](
s"GBTClassificationModel (uid=$uid) with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
+ *
+ * @see [[DecisionTreeClassificationModel.featureImportances]]
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
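This is a new public API on GBTClassificationModel: before this patch only the decision tree and random forest models exposed `featureImportances`. A sketch of the new surface, again assuming a hypothetical `training` DataFrame:

    import org.apache.spark.ml.classification.GBTClassifier

    val gbtModel = new GBTClassifier()
      .setMaxIter(20)
      .fit(training)

    // Average of per-tree normalized importances, renormalized to sum to 1
    // (whenever at least one tree has a nonzero gain).
    val imp = gbtModel.featureImportances
    assert(math.abs(imp.toArray.sum - 1.0) < 1e-6)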
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 5da04d341d..82fa05a604 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -222,19 +222,15 @@ final class RandomForestClassificationModel private[ml] (
/**
* Estimate of the importance of each feature.
*
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
*
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
+ * @see [[DecisionTreeClassificationModel.featureImportances]]
*/
@Since("1.5.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
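Behavior is unchanged for random forests; only the documentation is condensed and the helper's home moves. One common way to consume the resulting (possibly sparse) vector is to rank features, sketched below for a hypothetical fitted model:

    import org.apache.spark.ml.classification.RandomForestClassifier

    val rfModel = new RandomForestClassifier()
      .setNumTrees(100)
      .fit(training)  // hypothetical DataFrame with "label"/"features"

    // Indices refer to positions in the features vector.
    rfModel.featureImportances.toArray.zipWithIndex
      .sortBy { case (importance, _) => -importance }
      .take(3)
      .foreach { case (imp, idx) => println(s"feature $idx: $imp") }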
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 50ac96eb5e..0a3d00e470 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -203,7 +203,7 @@ final class DecisionTreeRegressionModel private[ml] (
* to determine feature importance instead.
*/
@Since("2.0.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
/** Convert to spark.mllib DecisionTreeModel (losing some information) */
override private[spark] def toOld: OldDecisionTreeModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index da5b77e8fa..8fca35da51 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -224,6 +224,19 @@ final class GBTRegressionModel private[ml](
s"GBTRegressionModel (uid=$uid) with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
+ *
+ * @see [[DecisionTreeRegressionModel.featureImportances]]
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
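The regression side gains the same new member with identical semantics. A brief sketch under the same assumptions as above:

    import org.apache.spark.ml.regression.GBTRegressor

    val gbtReg = new GBTRegressor()
      .setMaxIter(50)
      .fit(training)  // hypothetical DataFrame with "label"/"features"

    // The dense view is easier to read when the feature count is small.
    println(gbtReg.featureImportances.toDense)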
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 798947b94a..5b3f3a1f5d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -181,19 +181,15 @@ final class RandomForestRegressionModel private[ml] (
/**
* Estimate of the importance of each feature.
*
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
*
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
+ * @see [[DecisionTreeRegressionModel.featureImportances]]
*/
@Since("1.5.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 1c8a9b4dfe..b37f4e891e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -19,7 +19,9 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.internal.Logging
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.ml.tree.DecisionTreeModel
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 7774ae64e5..cccf052b3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -26,7 +26,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator,
@@ -35,7 +34,6 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
@@ -1105,112 +1103,4 @@ private[spark] object RandomForest extends Logging {
}
}
- /**
- * Given a Random Forest model, compute the importance of each feature.
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
- *
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
- *
- * @param trees Unweighted forest of trees
- * @param numFeatures Number of features in model (even if not all are explicitly used by
- * the model).
- * If -1, then numFeatures is set based on the max feature index in all trees.
- * @return Feature importance values, of length numFeatures.
- */
- private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
- val totalImportances = new OpenHashMap[Int, Double]()
- trees.foreach { tree =>
- // Aggregate feature importance vector for this tree
- val importances = new OpenHashMap[Int, Double]()
- computeFeatureImportance(tree.rootNode, importances)
- // Normalize importance vector for this tree, and add it to total.
- // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
- val treeNorm = importances.map(_._2).sum
- if (treeNorm != 0) {
- importances.foreach { case (idx, impt) =>
- val normImpt = impt / treeNorm
- totalImportances.changeValue(idx, normImpt, _ + normImpt)
- }
- }
- }
- // Normalize importances
- normalizeMapValues(totalImportances)
- // Construct vector
- val d = if (numFeatures != -1) {
- numFeatures
- } else {
- // Find max feature index used in trees
- val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
- maxFeatureIndex + 1
- }
- if (d == 0) {
- assert(totalImportances.size == 0, s"Unknown error in computing feature" +
- s" importance: No splits found, but some non-zero importances.")
- }
- val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
- Vectors.sparse(d, indices.toArray, values.toArray)
- }
-
- /**
- * Given a Decision Tree model, compute the importance of each feature.
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
- *
- * This feature importance is calculated as follows:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- *
- * @param tree Decision tree to compute importances for.
- * @param numFeatures Number of features in model (even if not all are explicitly used by
- * the model).
- * If -1, then numFeatures is set based on the max feature index in all trees.
- * @return Feature importance values, of length numFeatures.
- */
- private[ml] def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
- featureImportances(Array(tree), numFeatures)
- }
-
- /**
- * Recursive method for computing feature importances for one tree.
- * This walks down the tree, adding to the importance of 1 feature at each node.
- * @param node Current node in recursion
- * @param importances Aggregate feature importances, modified by this method
- */
- private[impl] def computeFeatureImportance(
- node: Node,
- importances: OpenHashMap[Int, Double]): Unit = {
- node match {
- case n: InternalNode =>
- val feature = n.split.featureIndex
- val scaledGain = n.gain * n.impurityStats.count
- importances.changeValue(feature, scaledGain, _ + scaledGain)
- computeFeatureImportance(n.leftChild, importances)
- computeFeatureImportance(n.rightChild, importances)
- case n: LeafNode =>
- // do nothing
- }
- }
-
- /**
- * Normalize the values of this map to sum to 1, in place.
- * If all values are 0, this method does nothing.
- * @param map Map with non-negative values.
- */
- private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
- val total = map.map(_._2).sum
- if (total != 0) {
- val keys = map.iterator.map(_._1).toArray
- keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
- }
- }
-
}
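Everything removed here moves verbatim into the new `TreeEnsembleModel` object in treeModels.scala (next file), leaving `RandomForest` as a pure training implementation. As a self-contained illustration of the arithmetic being moved, the sketch below redoes the per-tree normalization and averaging on plain `mutable.Map`s (standing in for the package-private `OpenHashMap`); the gain totals are made up:

    import scala.collection.mutable

    // Per-tree raw importances: featureIndex -> sum of (gain * node count).
    val tree1 = mutable.Map(0 -> 3.0, 2 -> 1.0)  // normalizes to 0.75, 0.25
    val tree2 = mutable.Map(0 -> 1.0, 1 -> 1.0)  // normalizes to 0.50, 0.50

    val total = mutable.Map.empty[Int, Double].withDefaultValue(0.0)
    for (tree <- Seq(tree1, tree2)) {
      val norm = tree.values.sum
      if (norm != 0) {
        tree.foreach { case (idx, imp) => total(idx) += imp / norm }
      }
    }

    // Final pass, as in normalizeMapValues.
    val grand = total.values.sum
    val importances = total.map { case (idx, imp) => idx -> imp / grand }
    // importances: 0 -> 0.625, 1 -> 0.25, 2 -> 0.125 (sums to 1)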
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index ef40c9068f..1fad9d6d8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -27,6 +27,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.collection.OpenHashMap
/**
* Abstraction for Decision Tree models.
@@ -115,6 +116,125 @@ private[ml] trait TreeEnsembleModel {
lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
}
+private[ml] object TreeEnsembleModel {
+
+ /**
+ * Given a tree ensemble model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * For collections of trees, including boosting and bagging, Hastie et al.
+ * propose to use the average of single tree importances across all trees in the ensemble.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ * - Normalize feature importance vector to sum to 1.
+ *
+ * References:
+ * - Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.
+ *
+ * @param trees Unweighted collection of trees
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @return Feature importance values, of length numFeatures.
+ */
+ def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
+ val totalImportances = new OpenHashMap[Int, Double]()
+ trees.foreach { tree =>
+ // Aggregate feature importance vector for this tree
+ val importances = new OpenHashMap[Int, Double]()
+ computeFeatureImportance(tree.rootNode, importances)
+ // Normalize importance vector for this tree, and add it to total.
+ // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
+ val treeNorm = importances.map(_._2).sum
+ if (treeNorm != 0) {
+ importances.foreach { case (idx, impt) =>
+ val normImpt = impt / treeNorm
+ totalImportances.changeValue(idx, normImpt, _ + normImpt)
+ }
+ }
+ }
+ // Normalize importances
+ normalizeMapValues(totalImportances)
+ // Construct vector
+ val d = if (numFeatures != -1) {
+ numFeatures
+ } else {
+ // Find max feature index used in trees
+ val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
+ maxFeatureIndex + 1
+ }
+ if (d == 0) {
+ assert(totalImportances.size == 0, s"Unknown error in computing feature" +
+ s" importance: No splits found, but some non-zero importances.")
+ }
+ val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
+ Vectors.sparse(d, indices.toArray, values.toArray)
+ }
+
+ /**
+ * Given a Decision Tree model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ *
+ * @param tree Decision tree to compute importances for.
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @return Feature importance values, of length numFeatures.
+ */
+ def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
+ featureImportances(Array(tree), numFeatures)
+ }
+
+ /**
+ * Recursive method for computing feature importances for one tree.
+ * This walks down the tree, adding to the importance of 1 feature at each node.
+ *
+ * @param node Current node in recursion
+ * @param importances Aggregate feature importances, modified by this method
+ */
+ def computeFeatureImportance(
+ node: Node,
+ importances: OpenHashMap[Int, Double]): Unit = {
+ node match {
+ case n: InternalNode =>
+ val feature = n.split.featureIndex
+ val scaledGain = n.gain * n.impurityStats.count
+ importances.changeValue(feature, scaledGain, _ + scaledGain)
+ computeFeatureImportance(n.leftChild, importances)
+ computeFeatureImportance(n.rightChild, importances)
+ case n: LeafNode =>
+ // do nothing
+ }
+ }
+
+ /**
+ * Normalize the values of this map to sum to 1, in place.
+ * If all values are 0, this method does nothing.
+ *
+ * @param map Map with non-negative values.
+ */
+ def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
+ val total = map.map(_._2).sum
+ if (total != 0) {
+ val keys = map.iterator.map(_._1).toArray
+ keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
+ }
+ }
+}
+
/** Helper classes for tree model persistence */
private[ml] object DecisionTreeModelReadWrite {
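A worked illustration of the `scaledGain` accumulation in `computeFeatureImportance`: the real method recurses over the `InternalNode`s of a fitted tree, but the per-feature sums reduce to the arithmetic below (node statistics are made up):

    // (featureIndex, impurity gain at the node, instances reaching the node)
    val nodes = Seq((0, 0.20, 100L), (1, 0.10, 60L), (0, 0.05, 40L))

    // importance(j) = sum over nodes splitting on feature j of gain * count.
    val raw = nodes
      .groupBy { case (feature, _, _) => feature }
      .map { case (feature, splits) =>
        feature -> splits.map { case (_, gain, n) => gain * n }.sum
      }
    // raw: 0 -> 22.0 (20.0 + 2.0), 1 -> 6.0

    val norm = raw.values.sum                          // 28.0
    val perTree = raw.map { case (f, v) => f -> v / norm }
    // perTree: 0 -> ~0.786, 1 -> ~0.214, this tree's normalized vector,
    // which featureImportances then averages across the ensemble.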