author     sethah <seth.hendrickson16@gmail.com>         2016-03-09 14:44:51 -0800
committer  Joseph K. Bradley <joseph@databricks.com>     2016-03-09 14:44:51 -0800
commit     e1772d3f19bed7e69a80de7900ed22d3eeb05300 (patch)
tree       9db2d2a2b3ac0786141cc51790dc4de0f8e307c5 /mllib
parent     c6aa356cd831ea2d159568b699bd5b791f3d8f25 (diff)
[SPARK-11861][ML] Add feature importances for decision trees
This patch adds an API entry point for single decision tree feature importances.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #9912 from sethah/SPARK-11861.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala       | 19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala       |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala            | 19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala            |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala                      | 30
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala  | 21
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala  | 10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala                              | 13
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala       | 20
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala       | 13
10 files changed, 126 insertions(+), 27 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 8c4cec1326..7f0397f6bd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -169,6 +169,25 @@ final class DecisionTreeClassificationModel private[ml] (
s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ *
+ * Note: Feature importance for single decision trees can have high variance due to
+ * correlated predictor variables. Consider using a [[RandomForestClassifier]]
+ * to determine feature importance instead.
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
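For reference, a minimal usage sketch of the new single-tree entry point (not part of the patch; the DataFrame `training` with the standard "label"/"features" columns is assumed):

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Hypothetical input: a DataFrame `training` with "label" and "features" columns.
val dt = new DecisionTreeClassifier()
  .setImpurity("gini")
  .setMaxDepth(3)
val model = dt.fit(training)

// New in this patch: a Vector of length numFeatures whose entries are
// non-negative and sum to 1.
val importances = model.featureImportances
```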
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index f7d662df2f..5da04d341d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -230,10 +230,10 @@ final class RandomForestClassificationModel private[ml] (
* - Average over trees:
* - importance(feature j) = sum (over nodes which split on feature j) of the gain,
* where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree based on total number of training instances used
- * to build tree.
+ * - Normalize importances for tree to sum to 1.
* - Normalize feature importance vector to sum to 1.
*/
+ @Since("1.5.0")
lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
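The forest-level accessor (now annotated `@Since("1.5.0")`) is used the same way; a brief sketch mirroring the test suites below, assuming a prepared DataFrame `df`:

```scala
import org.apache.spark.ml.classification.RandomForestClassifier

val rf = new RandomForestClassifier().setNumTrees(10).setSeed(123)
val importances = rf.fit(df).featureImportances
// Index of the feature the forest considers most important.
val mostImportantFeature = importances.argmax
```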
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 18c94f3638..897b23383c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -169,6 +169,25 @@ final class DecisionTreeRegressionModel private[ml] (
s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes nodes"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ *
+ * Note: Feature importance for single decision trees can have high variance due to
+ * correlated predictor variables. Consider using a [[RandomForestRegressor]]
+ * to determine feature importance instead.
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+
/** Convert to a model in the old API */
private[ml] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 71e40b513e..798947b94a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -189,10 +189,10 @@ final class RandomForestRegressionModel private[ml] (
* - Average over trees:
* - importance(feature j) = sum (over nodes which split on feature j) of the gain,
* where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree based on total number of training instances used
- * to build tree.
+ * - Normalize importances for tree to sum to 1.
* - Normalize feature importance vector to sum to 1.
*/
+ @Since("1.5.0")
lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
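To make the two-step normalization in the doc comment concrete, here is a worked example in plain Scala (an illustration of the described calculation, not Spark's internals), using gain totals for three features over two trees, each gain already scaled by node instance counts:

```scala
val tree1 = Array(0.0, 3.0, 1.0) // raw gain sums for features 0..2
val tree2 = Array(1.0, 1.0, 0.0)

def normalize(v: Array[Double]): Array[Double] = {
  val s = v.sum
  if (s == 0.0) v else v.map(_ / s)
}

// Step 1: normalize each tree's importances to sum to 1.
// tree1 -> (0.0, 0.75, 0.25); tree2 -> (0.5, 0.5, 0.0)
// Step 2: average over trees, then normalize the final vector to sum to 1.
val averaged = normalize(tree1).zip(normalize(tree2)).map { case (a, b) => (a + b) / 2 }
val importances = normalize(averaged) // (0.25, 0.625, 0.125)
```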
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index ea733d577a..f994c258b2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -1089,12 +1089,9 @@ private[ml] object RandomForest extends Logging {
* - Average over trees:
* - importance(feature j) = sum (over nodes which split on feature j) of the gain,
* where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree based on total number of training instances used
- * to build tree.
+ * - Normalize importances for tree to sum to 1.
* - Normalize feature importance vector to sum to 1.
*
- * Note: This should not be used with Gradient-Boosted Trees. It only makes sense for
- * independently trained trees.
* @param trees Unweighted forest of trees
* @param numFeatures Number of features in model (even if not all are explicitly used by
* the model).
@@ -1128,14 +1125,35 @@ private[ml] object RandomForest extends Logging {
maxFeatureIndex + 1
}
if (d == 0) {
- assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" +
- s" importance: No splits in forest, but some non-zero importances.")
+ assert(totalImportances.size == 0, s"Unknown error in computing feature" +
+ s" importance: No splits found, but some non-zero importances.")
}
val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
Vectors.sparse(d, indices.toArray, values.toArray)
}
/**
+ * Given a Decision Tree model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ *
+ * @param tree Decision tree to compute importances for.
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @return Feature importance values, of length numFeatures.
+ */
+ private[ml] def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
+ featureImportances(Array(tree), numFeatures)
+ }
+
+ /**
* Recursive method for computing feature importances for one tree.
* This walks down the tree, adding to the importance of 1 feature at each node.
* @param node Current node in recursion
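The recursive walk referenced above can be sketched as follows. This is a toy version under an assumed node type (`ToyNode` is hypothetical; Spark's implementation walks its own `InternalNode`/`LeafNode` classes), showing only the gain-accumulation step:

```scala
import scala.collection.mutable

sealed trait ToyNode
case object ToyLeaf extends ToyNode
final case class ToySplit(featureIndex: Int, gain: Double, numInstances: Long,
                          left: ToyNode, right: ToyNode) extends ToyNode

def accumulate(node: ToyNode, acc: mutable.Map[Int, Double]): Unit = node match {
  case ToySplit(f, gain, n, l, r) =>
    // Each split credits its feature with the gain, scaled by the number
    // of instances passing through the node.
    acc(f) = acc.getOrElse(f, 0.0) + gain * n
    accumulate(l, acc)
    accumulate(r, acc)
  case ToyLeaf => ()
}
```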
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 9169bcd390..6d68364499 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -313,6 +313,27 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
}
}
+ test("Feature importance with toy data") {
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("gini")
+ .setMaxDepth(3)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val numFeatures = data.first().features.size
+ val categoricalFeatures = (0 to numFeatures).map(i => (i, 2)).toMap
+ val df = TreeTests.setMetadata(data, categoricalFeatures, 2)
+
+ val model = dt.fit(df)
+
+ val importances = model.featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+ assert(importances.toArray.sum === 1.0)
+ assert(importances.toArray.forall(_ >= 0.0))
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index deb8ec771c..6b810ab9ea 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -167,19 +167,15 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
.setSeed(123)
// In this data, feature 1 is very important.
- val data: RDD[LabeledPoint] = sc.parallelize(Seq(
- new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
- new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
- new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
- new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
- new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
- ))
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
val categoricalFeatures = Map.empty[Int, Int]
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
val importances = rf.fit(df).featureImportances
val mostImportantFeature = importances.argmax
assert(mostImportantFeature === 1)
+ assert(importances.toArray.sum === 1.0)
+ assert(importances.toArray.forall(_ >= 0.0))
}
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index a808177cb9..5561f6f0ef 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -19,10 +19,12 @@ package org.apache.spark.ml.impl
import scala.collection.JavaConverters._
+import org.apache.spark.SparkContext
import org.apache.spark.SparkFunSuite
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
@@ -141,4 +143,15 @@ private[ml] object TreeTests extends SparkFunSuite {
val pred = parentImp.predict
new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp)
}
+
+ /**
+ * Create some toy data for testing feature importances.
+ */
+ def featureImportanceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq(
+ new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
+ new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
+ ))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 13165f6701..56b335a33a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -96,6 +96,26 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
}
}
+ test("Feature importance with toy data") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(3)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+ val model = dt.fit(df)
+
+ val importances = model.featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+ assert(importances.toArray.sum === 1.0)
+ assert(importances.toArray.forall(_ >= 0.0))
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 7e751e4b55..efb117f8f9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.util.MLTestingUtils
-import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -82,23 +81,17 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
.setSeed(123)
// In this data, feature 1 is very important.
- val data: RDD[LabeledPoint] = sc.parallelize(Seq(
- new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
- new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
- new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
- new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
- new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
- ))
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
val categoricalFeatures = Map.empty[Int, Int]
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
val model = rf.fit(df)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
val importances = model.featureImportances
val mostImportantFeature = importances.argmax
assert(mostImportantFeature === 1)
+ assert(importances.toArray.sum === 1.0)
+ assert(importances.toArray.forall(_ >= 0.0))
}
/////////////////////////////////////////////////////////////////////////////