diff options
Diffstat (limited to 'mllib/src/test')
3 files changed, 51 insertions, 3 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index f3680ed044..bf7481e8a3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -121,6 +121,31 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { */ ///////////////////////////////////////////////////////////////////////////// + // Tests of feature importance + ///////////////////////////////////////////////////////////////////////////// + test("Feature importance with toy data") { + val numClasses = 2 + val gbt = new GBTClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setMaxIter(5) + .setSubsamplingRate(1.0) + .setStepSize(0.5) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + + val importances = gbt.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + assert(importances.toArray.sum === 1.0) + assert(importances.toArray.forall(_ >= 0.0)) + } + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 84148a8a4a..dfb8418086 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -132,6 +132,29 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { */ ///////////////////////////////////////////////////////////////////////////// + // Tests of feature importance + ///////////////////////////////////////////////////////////////////////////// + test("Feature importance with toy data") { + val gbt = new GBTRegressor() + .setMaxDepth(3) + .setMaxIter(5) + .setSubsamplingRate(1.0) + .setStepSize(0.5) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val importances = gbt.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + assert(importances.toArray.sum === 1.0) + assert(importances.toArray.forall(_ >= 0.0)) + } + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 361366fde7..441338e74e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -471,7 +471,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // Test feature importance computed at different subtrees. def testNode(node: Node, expected: Map[Int, Double]): Unit = { val map = new OpenHashMap[Int, Double]() - RandomForest.computeFeatureImportance(node, map) + TreeEnsembleModel.computeFeatureImportance(node, map) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } @@ -493,7 +493,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3) .asInstanceOf[DecisionTreeModel] } - val importances: Vector = RandomForest.featureImportances(trees, 2) + val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0, (feature1importance / tree2norm) / 2.0) @@ -504,7 +504,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val map = new OpenHashMap[Int, Double]() map(0) = 1.0 map(2) = 2.0 - RandomForest.normalizeMapValues(map) + TreeEnsembleModel.normalizeMapValues(map) val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } |