aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-08-03 12:17:46 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-03 12:17:46 -0700
commitff9169a002f1b75231fd25b7d04157a912503038 (patch)
treeef57aa63ad02760806657e491a78f15f5daa7f66 /mllib
parent703e44bff19f4c394f6f9bff1ce9152cdc68c51e (diff)
downloadspark-ff9169a002f1b75231fd25b7d04157a912503038.tar.gz
spark-ff9169a002f1b75231fd25b7d04157a912503038.tar.bz2
spark-ff9169a002f1b75231fd25b7d04157a912503038.zip
[SPARK-5133] [ML] Added featureImportance to RandomForestClassifier and Regressor
Added featureImportance to RandomForestClassifier and Regressor. This follows the scikit-learn implementation here: [https://github.com/scikit-learn/scikit-learn/blob/a95203b249c1cf392f86d001ad999e29b2392739/sklearn/tree/_tree.pyx#L3341] CC: yanboliang Would you mind taking a look? Thanks! Author: Joseph K. Bradley <joseph@databricks.com> Author: Feynman Liang <fliang@databricks.com> Closes #7838 from jkbradley/dt-feature-importance and squashes the following commits: 72a167a [Joseph K. Bradley] fixed unit test 86cea5f [Joseph K. Bradley] Modified RF featuresImportances to return Vector instead of Map 5aa74f0 [Joseph K. Bradley] finally fixed unit test for real 33df5db [Joseph K. Bradley] fix unit test 42a2d3b [Joseph K. Bradley] fix unit test fe94e72 [Joseph K. Bradley] modified feature importance unit tests cc693ee [Feynman Liang] Add classifier tests 79a6f87 [Feynman Liang] Compare dense vectors in test 21d01fc [Feynman Liang] Added failing SKLearn test ac0b254 [Joseph K. Bradley] Added featureImportance to RandomForestClassifier/Regressor. Need to add unit tests
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala30
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala33
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala19
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala92
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala6
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala31
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala18
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala27
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala107
11 files changed, 351 insertions, 16 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 56e80cc8fe..b59826a594 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -95,7 +95,8 @@ final class RandomForestClassifier(override val uid: String)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeClassificationModel])
- new RandomForestClassificationModel(trees, numClasses)
+ val numFeatures = oldDataset.first().features.size
+ new RandomForestClassificationModel(trees, numFeatures, numClasses)
}
override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
@@ -118,11 +119,13 @@ object RandomForestClassifier {
* features.
* @param _trees Decision trees in the ensemble.
* Warning: These have null parents.
+ * @param numFeatures Number of features used by this model
*/
@Experimental
final class RandomForestClassificationModel private[ml] (
override val uid: String,
private val _trees: Array[DecisionTreeClassificationModel],
+ val numFeatures: Int,
override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
@@ -133,8 +136,8 @@ final class RandomForestClassificationModel private[ml] (
* Construct a random forest classification model, with all trees weighted equally.
* @param trees Component trees
*/
- def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) =
- this(Identifiable.randomUID("rfc"), trees, numClasses)
+ def this(trees: Array[DecisionTreeClassificationModel], numFeatures: Int, numClasses: Int) =
+ this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
@@ -182,13 +185,30 @@ final class RandomForestClassificationModel private[ml] (
}
override def copy(extra: ParamMap): RandomForestClassificationModel = {
- copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
+ copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
}
override def toString: String = {
s"RandomForestClassificationModel with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree based on total number of training instances used
+ * to build tree.
+ * - Normalize feature importance vector to sum to 1.
+ */
+ lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
@@ -210,6 +230,6 @@ private[ml] object RandomForestClassificationModel {
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
- new RandomForestClassificationModel(uid, newTrees, numClasses)
+ new RandomForestClassificationModel(uid, newTrees, -1, numClasses)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 17fb1ad5e1..1ee43c8725 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
+
/**
* :: Experimental ::
@@ -87,7 +87,8 @@ final class RandomForestRegressor(override val uid: String)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeRegressionModel])
- new RandomForestRegressionModel(trees)
+ val numFeatures = oldDataset.first().features.size
+ new RandomForestRegressionModel(trees, numFeatures)
}
override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
@@ -108,11 +109,13 @@ object RandomForestRegressor {
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
* It supports both continuous and categorical features.
* @param _trees Decision trees in the ensemble.
+ * @param numFeatures Number of features used by this model
*/
@Experimental
final class RandomForestRegressionModel private[ml] (
override val uid: String,
- private val _trees: Array[DecisionTreeRegressionModel])
+ private val _trees: Array[DecisionTreeRegressionModel],
+ val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
with TreeEnsembleModel with Serializable {
@@ -122,7 +125,8 @@ final class RandomForestRegressionModel private[ml] (
* Construct a random forest regression model, with all trees weighted equally.
* @param trees Component trees
*/
- def this(trees: Array[DecisionTreeRegressionModel]) = this(Identifiable.randomUID("rfr"), trees)
+ def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
+ this(Identifiable.randomUID("rfr"), trees, numFeatures)
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
@@ -147,13 +151,30 @@ final class RandomForestRegressionModel private[ml] (
}
override def copy(extra: ParamMap): RandomForestRegressionModel = {
- copyValues(new RandomForestRegressionModel(uid, _trees), extra)
+ copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra)
}
override def toString: String = {
s"RandomForestRegressionModel with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree based on total number of training instances used
+ * to build tree.
+ * - Normalize feature importance vector to sum to 1.
+ */
+ lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
@@ -173,6 +194,6 @@ private[ml] object RandomForestRegressionModel {
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestRegressionModel(parent.uid, newTrees)
+ new RandomForestRegressionModel(parent.uid, newTrees, -1)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index 8879352a60..cd24931293 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -44,7 +44,7 @@ sealed abstract class Node extends Serializable {
* and probabilities.
* For classification, the array of class counts must be normalized to a probability distribution.
*/
- private[tree] def impurityStats: ImpurityCalculator
+ private[ml] def impurityStats: ImpurityCalculator
/** Recursive prediction helper method */
private[ml] def predictImpl(features: Vector): LeafNode
@@ -72,6 +72,12 @@ sealed abstract class Node extends Serializable {
* @param id Node ID using old format IDs
*/
private[ml] def toOld(id: Int): OldNode
+
+ /**
+ * Trace down the tree, and return the largest feature index used in any split.
+ * @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
+ */
+ private[ml] def maxSplitFeatureIndex(): Int
}
private[ml] object Node {
@@ -109,7 +115,7 @@ private[ml] object Node {
final class LeafNode private[ml] (
override val prediction: Double,
override val impurity: Double,
- override val impurityStats: ImpurityCalculator) extends Node {
+ override private[ml] val impurityStats: ImpurityCalculator) extends Node {
override def toString: String =
s"LeafNode(prediction = $prediction, impurity = $impurity)"
@@ -129,6 +135,8 @@ final class LeafNode private[ml] (
new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
impurity, isLeaf = true, None, None, None, None)
}
+
+ override private[ml] def maxSplitFeatureIndex(): Int = -1
}
/**
@@ -150,7 +158,7 @@ final class InternalNode private[ml] (
val leftChild: Node,
val rightChild: Node,
val split: Split,
- override val impurityStats: ImpurityCalculator) extends Node {
+ override private[ml] val impurityStats: ImpurityCalculator) extends Node {
override def toString: String = {
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
@@ -190,6 +198,11 @@ final class InternalNode private[ml] (
new OldPredict(leftChild.prediction, prob = 0.0),
new OldPredict(rightChild.prediction, prob = 0.0))))
}
+
+ override private[ml] def maxSplitFeatureIndex(): Int = {
+ math.max(split.featureIndex,
+ math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex()))
+ }
}
private object InternalNode {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index a8b90d9d26..4ac51a4754 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -26,6 +26,7 @@ import org.apache.spark.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
@@ -34,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
@@ -1113,4 +1115,94 @@ private[ml] object RandomForest extends Logging {
}
}
+ /**
+ * Given a Random Forest model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree based on total number of training instances used
+ * to build tree.
+ * - Normalize feature importance vector to sum to 1.
+ *
+ * Note: This should not be used with Gradient-Boosted Trees. It only makes sense for
+ * independently trained trees.
+ * @param trees Unweighted forest of trees
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @return Feature importance values, of length numFeatures.
+ */
+ private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
+ val totalImportances = new OpenHashMap[Int, Double]()
+ trees.foreach { tree =>
+ // Aggregate feature importance vector for this tree
+ val importances = new OpenHashMap[Int, Double]()
+ computeFeatureImportance(tree.rootNode, importances)
+ // Normalize importance vector for this tree, and add it to total.
+ // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
+ val treeNorm = importances.map(_._2).sum
+ if (treeNorm != 0) {
+ importances.foreach { case (idx, impt) =>
+ val normImpt = impt / treeNorm
+ totalImportances.changeValue(idx, normImpt, _ + normImpt)
+ }
+ }
+ }
+ // Normalize importances
+ normalizeMapValues(totalImportances)
+ // Construct vector
+ val d = if (numFeatures != -1) {
+ numFeatures
+ } else {
+ // Find max feature index used in trees
+ val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
+ maxFeatureIndex + 1
+ }
+ if (d == 0) {
+ assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" +
+ s" importance: No splits in forest, but some non-zero importances.")
+ }
+ val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
+ Vectors.sparse(d, indices.toArray, values.toArray)
+ }
+
+ /**
+ * Recursive method for computing feature importances for one tree.
+ * This walks down the tree, adding to the importance of 1 feature at each node.
+ * @param node Current node in recursion
+ * @param importances Aggregate feature importances, modified by this method
+ */
+ private[impl] def computeFeatureImportance(
+ node: Node,
+ importances: OpenHashMap[Int, Double]): Unit = {
+ node match {
+ case n: InternalNode =>
+ val feature = n.split.featureIndex
+ val scaledGain = n.gain * n.impurityStats.count
+ importances.changeValue(feature, scaledGain, _ + scaledGain)
+ computeFeatureImportance(n.leftChild, importances)
+ computeFeatureImportance(n.rightChild, importances)
+ case n: LeafNode =>
+ // do nothing
+ }
+ }
+
+ /**
+ * Normalize the values of this map to sum to 1, in place.
+ * If all values are 0, this method does nothing.
+ * @param map Map with non-negative values.
+ */
+ private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
+ val total = map.map(_._2).sum
+ if (total != 0) {
+ val keys = map.iterator.map(_._1).toArray
+ keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
+ }
+ }
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 22873909c3..b77191156f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -53,6 +53,12 @@ private[ml] trait DecisionTreeModel {
val header = toString + "\n"
header + rootNode.subtreeToString(2)
}
+
+ /**
+ * Trace down the tree, and return the largest feature index used in any split.
+ * @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
+ */
+ private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex()
}
/**
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 32d0b3856b..a66a1e1292 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
@@ -85,6 +86,7 @@ public class JavaRandomForestClassifierSuite implements Serializable {
model.toDebugString();
model.trees();
model.treeWeights();
+ Vector importances = model.featureImportances();
/*
// TODO: Add test once save/load are implemented. SPARK-6725
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index e306ebadfe..a00ce5e249 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
@@ -85,6 +86,7 @@ public class JavaRandomForestRegressorSuite implements Serializable {
model.toDebugString();
model.trees();
model.treeWeights();
+ Vector importances = model.featureImportances();
/*
// TODO: Add test once save/load are implemented. SPARK-6725
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index edf848b21a..6ca4b5aa5f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -67,7 +67,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new RandomForestClassifier)
val model = new RandomForestClassificationModel("rfc",
- Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2)
+ Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2)
ParamsSuite.checkParams(model)
}
@@ -150,6 +150,35 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
}
/////////////////////////////////////////////////////////////////////////////
+ // Tests of feature importance
+ /////////////////////////////////////////////////////////////////////////////
+ test("Feature importance with toy data") {
+ val numClasses = 2
+ val rf = new RandomForestClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(3)
+ .setNumTrees(3)
+ .setFeatureSubsetStrategy("all")
+ .setSubsamplingRate(1.0)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = sc.parallelize(Seq(
+ new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
+ new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
+ ))
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+
+ val importances = rf.fit(df).featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index 778abcba22..460849c79f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -124,4 +124,22 @@ private[ml] object TreeTests extends SparkFunSuite {
"checkEqual failed since the two tree ensembles were not identical")
}
}
+
+ /**
+ * Helper method for constructing a tree for testing.
+ * Given left, right children, construct a parent node.
+ * @param split Split for parent node
+ * @return Parent node with children attached
+ */
+ def buildParentNode(left: Node, right: Node, split: Split): Node = {
+ val leftImp = left.impurityStats
+ val rightImp = right.impurityStats
+ val parentImp = leftImp.copy.add(rightImp)
+ val leftWeight = leftImp.count / parentImp.count.toDouble
+ val rightWeight = rightImp.count / parentImp.count.toDouble
+ val gain = parentImp.calculate() -
+ (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
+ val pred = parentImp.predict
+ new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index b24ecaa57c..992ce95624 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -26,7 +27,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* Test suite for [[RandomForestRegressor]].
*/
@@ -71,6 +71,31 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
regressionTestWithContinuousFeatures(rf)
}
+ test("Feature importance with toy data") {
+ val rf = new RandomForestRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(3)
+ .setNumTrees(3)
+ .setFeatureSubsetStrategy("all")
+ .setSubsamplingRate(1.0)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = sc.parallelize(Seq(
+ new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
+ new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
+ ))
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+ val importances = rf.fit(df).featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
new file mode 100644
index 0000000000..dc852795c7
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.tree.impurity.GiniCalculator
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.collection.OpenHashMap
+
+/**
+ * Test suite for [[RandomForest]].
+ */
+class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ import RandomForestSuite.mapToVec
+
+ test("computeFeatureImportance, featureImportances") {
+ /* Build tree for testing, with this structure:
+ grandParent
+ left2 parent
+ left right
+ */
+ val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
+ val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
+
+ val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
+ val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
+
+ val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
+ val parentImp = parent.impurityStats
+
+ val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
+ val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
+
+ val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
+ val grandImp = grandParent.impurityStats
+
+ // Test feature importance computed at different subtrees.
+ def testNode(node: Node, expected: Map[Int, Double]): Unit = {
+ val map = new OpenHashMap[Int, Double]()
+ RandomForest.computeFeatureImportance(node, map)
+ assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
+ }
+
+ // Leaf node
+ testNode(left, Map.empty[Int, Double])
+
+ // Internal node with 2 leaf children
+ val feature0importance = parentImp.calculate() * parentImp.count -
+ (leftImp.calculate() * leftImp.count + rightImp.calculate() * rightImp.count)
+ testNode(parent, Map(0 -> feature0importance))
+
+ // Full tree
+ val feature1importance = grandImp.calculate() * grandImp.count -
+ (left2Imp.calculate() * left2Imp.count + parentImp.calculate() * parentImp.count)
+ testNode(grandParent, Map(0 -> feature0importance, 1 -> feature1importance))
+
+ // Forest consisting of (full tree) + (internal node with 2 leafs)
+ val trees = Array(parent, grandParent).map { root =>
+ new DecisionTreeClassificationModel(root, numClasses = 3).asInstanceOf[DecisionTreeModel]
+ }
+ val importances: Vector = RandomForest.featureImportances(trees, 2)
+ val tree2norm = feature0importance + feature1importance
+ val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
+ (feature1importance / tree2norm) / 2.0)
+ assert(importances ~== expected relTol 0.01)
+ }
+
+ test("normalizeMapValues") {
+ val map = new OpenHashMap[Int, Double]()
+ map(0) = 1.0
+ map(2) = 2.0
+ RandomForest.normalizeMapValues(map)
+ val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
+ assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
+ }
+
+}
+
+private object RandomForestSuite {
+
+ def mapToVec(map: Map[Int, Double]): Vector = {
+ val size = (map.keys.toSeq :+ 0).max + 1
+ val (indices, values) = map.toSeq.sortBy(_._1).unzip
+ Vectors.sparse(size, indices.toArray, values.toArray)
+ }
+}