aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala | 25
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala | 23
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | 6
3 files changed, 51 insertions, 3 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index f3680ed044..bf7481e8a3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -121,6 +121,31 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
/////////////////////////////////////////////////////////////////////////////
+ // Tests of feature importance
+ /////////////////////////////////////////////////////////////////////////////
+ test("Feature importance with toy data") {
+   val numClasses = 2
+   val gbt = new GBTClassifier()
+     .setImpurity("Gini")
+     .setMaxDepth(3)
+     .setMaxIter(5)
+     .setSubsamplingRate(1.0)
+     .setStepSize(0.5)
+     .setSeed(123)
+
+   // In this data, feature 1 is very important.
+   val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+   val categoricalFeatures = Map.empty[Int, Int]
+   val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+   // Importances should peak at feature 1, be non-negative, and sum to 1 up to rounding error.
+   val importances = gbt.fit(df).featureImportances
+   val mostImportantFeature = importances.argmax
+   assert(mostImportantFeature === 1)
+   assert(math.abs(importances.toArray.sum - 1.0) < 1e-8)
+   assert(importances.toArray.forall(_ >= 0.0))
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 84148a8a4a..dfb8418086 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -132,6 +132,29 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
/////////////////////////////////////////////////////////////////////////////
+ // Tests of feature importance
+ /////////////////////////////////////////////////////////////////////////////
+ test("Feature importance with toy data") {
+   val gbt = new GBTRegressor()
+     .setMaxDepth(3)
+     .setMaxIter(5)
+     .setSubsamplingRate(1.0)
+     .setStepSize(0.5)
+     .setSeed(123)
+
+   // In this data, feature 1 is very important.
+   val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+   val categoricalFeatures = Map.empty[Int, Int]
+   val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+   // Importances should peak at feature 1, be non-negative, and sum to 1 up to rounding error.
+   val importances = gbt.fit(df).featureImportances
+   val mostImportantFeature = importances.argmax
+   assert(mostImportantFeature === 1)
+   assert(math.abs(importances.toArray.sum - 1.0) < 1e-8)
+   assert(importances.toArray.forall(_ >= 0.0))
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 361366fde7..441338e74e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -471,7 +471,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// Test feature importance computed at different subtrees.
def testNode(node: Node, expected: Map[Int, Double]): Unit = {
val map = new OpenHashMap[Int, Double]()
- RandomForest.computeFeatureImportance(node, map)
+ TreeEnsembleModel.computeFeatureImportance(node, map)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
@@ -493,7 +493,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3)
.asInstanceOf[DecisionTreeModel]
}
- val importances: Vector = RandomForest.featureImportances(trees, 2)
+ val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2)
val tree2norm = feature0importance + feature1importance
val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
(feature1importance / tree2norm) / 2.0)
@@ -504,7 +504,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val map = new OpenHashMap[Int, Double]()
map(0) = 1.0
map(2) = 2.0
- RandomForest.normalizeMapValues(map)
+ TreeEnsembleModel.normalizeMapValues(map)
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}