path: root/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala  114
1 file changed, 1 insertion, 113 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 7774ae64e5..7b1fd089f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -26,16 +26,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator,
- TimeTracker}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
@@ -332,7 +328,7 @@ private[spark] object RandomForest extends Logging {
/**
* Given a group of nodes, this finds the best split for each node.
*
- * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
+ * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]]
* @param metadata Learning and dataset metadata
* @param topNodes Root node for each tree. Used for matching instances with nodes.
* @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
@@ -1105,112 +1101,4 @@ private[spark] object RandomForest extends Logging {
}
}
- /**
- * Given a Random Forest model, compute the importance of each feature.
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
- *
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
- *
- * @param trees Unweighted forest of trees
- * @param numFeatures Number of features in model (even if not all are explicitly used by
- * the model).
- * If -1, then numFeatures is set based on the max feature index in all trees.
- * @return Feature importance values, of length numFeatures.
- */
- private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
- val totalImportances = new OpenHashMap[Int, Double]()
- trees.foreach { tree =>
- // Aggregate feature importance vector for this tree
- val importances = new OpenHashMap[Int, Double]()
- computeFeatureImportance(tree.rootNode, importances)
- // Normalize importance vector for this tree, and add it to total.
- // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
- val treeNorm = importances.map(_._2).sum
- if (treeNorm != 0) {
- importances.foreach { case (idx, impt) =>
- val normImpt = impt / treeNorm
- totalImportances.changeValue(idx, normImpt, _ + normImpt)
- }
- }
- }
- // Normalize importances
- normalizeMapValues(totalImportances)
- // Construct vector
- val d = if (numFeatures != -1) {
- numFeatures
- } else {
- // Find max feature index used in trees
- val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
- maxFeatureIndex + 1
- }
- if (d == 0) {
- assert(totalImportances.size == 0, s"Unknown error in computing feature" +
- s" importance: No splits found, but some non-zero importances.")
- }
- val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
- Vectors.sparse(d, indices.toArray, values.toArray)
- }
-
- /**
- * Given a Decision Tree model, compute the importance of each feature.
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
- *
- * This feature importance is calculated as follows:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- *
- * @param tree Decision tree to compute importances for.
- * @param numFeatures Number of features in model (even if not all are explicitly used by
- * the model).
- * If -1, then numFeatures is set based on the max feature index in all trees.
- * @return Feature importance values, of length numFeatures.
- */
- private[ml] def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
- featureImportances(Array(tree), numFeatures)
- }
-
- /**
- * Recursive method for computing feature importances for one tree.
- * This walks down the tree, adding to the importance of 1 feature at each node.
- * @param node Current node in recursion
- * @param importances Aggregate feature importances, modified by this method
- */
- private[impl] def computeFeatureImportance(
- node: Node,
- importances: OpenHashMap[Int, Double]): Unit = {
- node match {
- case n: InternalNode =>
- val feature = n.split.featureIndex
- val scaledGain = n.gain * n.impurityStats.count
- importances.changeValue(feature, scaledGain, _ + scaledGain)
- computeFeatureImportance(n.leftChild, importances)
- computeFeatureImportance(n.rightChild, importances)
- case n: LeafNode =>
- // do nothing
- }
- }
-
- /**
- * Normalize the values of this map to sum to 1, in place.
- * If all values are 0, this method does nothing.
- * @param map Map with non-negative values.
- */
- private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
- val total = map.map(_._2).sum
- if (total != 0) {
- val keys = map.iterator.map(_._1).toArray
- keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
- }
- }
-
}
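
The methods removed above fully specify the feature-importance algorithm: at each internal node, add gain scaled by the node's instance count to that node's split feature; normalize each tree's vector to sum to 1 so every tree contributes equally; then normalize the forest-level sum to 1. Below is a minimal standalone sketch of that same computation, assuming hypothetical SimpleNode/SimpleInternal/SimpleLeaf stand-ins in place of Spark's Node/InternalNode/LeafNode hierarchy and a plain mutable.Map in place of OpenHashMap; it is an illustration of the technique, not the actual Spark implementation.

import scala.collection.mutable

// Illustrative stand-ins for the tree node types; these are NOT the real
// org.apache.spark.ml.tree.Node classes.
sealed trait SimpleNode
case object SimpleLeaf extends SimpleNode
final case class SimpleInternal(
    featureIndex: Int,  // feature this node splits on
    gain: Double,       // impurity gain of the split
    count: Long,        // training instances passing through this node
    left: SimpleNode,
    right: SimpleNode) extends SimpleNode

object FeatureImportanceSketch {

  // Walk one tree, adding gain * count to the splitting feature at each
  // internal node (the recursion computeFeatureImportance performs above).
  def accumulate(node: SimpleNode, acc: mutable.Map[Int, Double]): Unit =
    node match {
      case SimpleInternal(feature, gain, count, left, right) =>
        val scaledGain = gain * count
        acc(feature) = acc.getOrElse(feature, 0.0) + scaledGain
        accumulate(left, acc)
        accumulate(right, acc)
      case SimpleLeaf => // leaves contribute nothing
    }

  // Normalize the map's values to sum to 1, in place; a no-op when the
  // total is 0, mirroring normalizeMapValues above.
  def normalize(m: mutable.Map[Int, Double]): Unit = {
    val total = m.values.sum
    if (total != 0) {
      m.keys.toList.foreach { k => m(k) = m(k) / total }
    }
  }

  // Forest-level importances: normalize per tree, sum across trees, then
  // normalize the final vector.
  def featureImportances(trees: Seq[SimpleNode]): Map[Int, Double] = {
    val total = mutable.Map.empty[Int, Double]
    trees.foreach { root =>
      val perTree = mutable.Map.empty[Int, Double]
      accumulate(root, perTree)
      normalize(perTree)
      perTree.foreach { case (k, v) =>
        total(k) = total.getOrElse(k, 0.0) + v
      }
    }
    normalize(total)
    total.toMap
  }

  def main(args: Array[String]): Unit = {
    val tree = SimpleInternal(featureIndex = 0, gain = 0.5, count = 100,
      left = SimpleInternal(featureIndex = 1, gain = 0.2, count = 60,
        left = SimpleLeaf, right = SimpleLeaf),
      right = SimpleLeaf)
    // gain * count: feature 0 -> 50.0, feature 1 -> 12.0; normalized:
    // Map(0 -> 0.806..., 1 -> 0.193...)
    println(featureImportances(Seq(tree)))
  }
}

The per-tree normalization step is what makes each tree count equally in the forest average regardless of its absolute gains; the removed code's TODO notes that normalizing by tree.rootNode.impurityStats.count was a considered alternative.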