aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2016-03-28 22:27:53 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-28 22:27:53 -0700
commitf6066b0c3c35ceea1706378145e15776c9b4415a (patch)
tree3878e899b9bbf1b568c82a2ff0783204a7e81b3e /mllib/src/test
parentd3638d7bffd4ee43db594c0669d86fb64d448fc8 (diff)
downloadspark-f6066b0c3c35ceea1706378145e15776c9b4415a.tar.gz
spark-f6066b0c3c35ceea1706378145e15776c9b4415a.tar.bz2
spark-f6066b0c3c35ceea1706378145e15776c9b4415a.zip
[SPARK-11730][ML] Add feature importances for GBTs.
## What changes were proposed in this pull request? Now that GBTs have been moved to ML, they can use the implementation of feature importance for random forests. This patch simply adds a `featureImportances` attribute to `GBTClassifier` and `GBTRegressor` and adds tests for each. GBT feature importances here simply average the feature importances for each tree in its ensemble. This follows the implementation from scikit-learn. This method is also suggested by J Friedman in [this paper](https://statweb.stanford.edu/~jhf/ftp/trebst.pdf). ## How was this patch tested? Unit tests were added to `GBTClassifierSuite` and `GBTRegressorSuite` to validate feature importances. Author: sethah <seth.hendrickson16@gmail.com> Closes #11961 from sethah/SPARK-11730.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala25
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala23
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala6
3 files changed, 51 insertions, 3 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index f3680ed044..bf7481e8a3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -121,6 +121,31 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
/////////////////////////////////////////////////////////////////////////////
+ // Tests of feature importance
+ /////////////////////////////////////////////////////////////////////////////
+ test("Feature importance with toy data") {
+ val numClasses = 2
+ val gbt = new GBTClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(3)
+ .setMaxIter(5)
+ .setSubsamplingRate(1.0)
+ .setStepSize(0.5)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+
+ val importances = gbt.fit(df).featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+ assert(importances.toArray.sum === 1.0)
+ assert(importances.toArray.forall(_ >= 0.0))
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 84148a8a4a..dfb8418086 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -132,6 +132,29 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
/////////////////////////////////////////////////////////////////////////////
+ // Tests of feature importance
+ /////////////////////////////////////////////////////////////////////////////
+ test("Feature importance with toy data") {
+ val gbt = new GBTRegressor()
+ .setMaxDepth(3)
+ .setMaxIter(5)
+ .setSubsamplingRate(1.0)
+ .setStepSize(0.5)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+ val importances = gbt.fit(df).featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+ assert(importances.toArray.sum === 1.0)
+ assert(importances.toArray.forall(_ >= 0.0))
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 361366fde7..441338e74e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -471,7 +471,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// Test feature importance computed at different subtrees.
def testNode(node: Node, expected: Map[Int, Double]): Unit = {
val map = new OpenHashMap[Int, Double]()
- RandomForest.computeFeatureImportance(node, map)
+ TreeEnsembleModel.computeFeatureImportance(node, map)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
@@ -493,7 +493,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3)
.asInstanceOf[DecisionTreeModel]
}
- val importances: Vector = RandomForest.featureImportances(trees, 2)
+ val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2)
val tree2norm = feature0importance + feature1importance
val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
(feature1importance / tree2norm) / 2.0)
@@ -504,7 +504,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val map = new OpenHashMap[Int, Double]()
map(0) = 1.0
map(2) = 2.0
- RandomForest.normalizeMapValues(map)
+ TreeEnsembleModel.normalizeMapValues(map)
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}