diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-08-03 12:17:46 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-03 12:17:46 -0700 |
commit | ff9169a002f1b75231fd25b7d04157a912503038 (patch) | |
tree | ef57aa63ad02760806657e491a78f15f5daa7f66 /mllib/src/test | |
parent | 703e44bff19f4c394f6f9bff1ce9152cdc68c51e (diff) | |
download | spark-ff9169a002f1b75231fd25b7d04157a912503038.tar.gz spark-ff9169a002f1b75231fd25b7d04157a912503038.tar.bz2 spark-ff9169a002f1b75231fd25b7d04157a912503038.zip |
[SPARK-5133] [ML] Added featureImportance to RandomForestClassifier and Regressor
Added featureImportance to RandomForestClassifier and Regressor.
This follows the scikit-learn implementation here: [https://github.com/scikit-learn/scikit-learn/blob/a95203b249c1cf392f86d001ad999e29b2392739/sklearn/tree/_tree.pyx#L3341]
CC: yanboliang Would you mind taking a look? Thanks!
Author: Joseph K. Bradley <joseph@databricks.com>
Author: Feynman Liang <fliang@databricks.com>
Closes #7838 from jkbradley/dt-feature-importance and squashes the following commits:
72a167a [Joseph K. Bradley] fixed unit test
86cea5f [Joseph K. Bradley] Modified RF featuresImportances to return Vector instead of Map
5aa74f0 [Joseph K. Bradley] finally fixed unit test for real
33df5db [Joseph K. Bradley] fix unit test
42a2d3b [Joseph K. Bradley] fix unit test
fe94e72 [Joseph K. Bradley] modified feature importance unit tests
cc693ee [Feynman Liang] Add classifier tests
79a6f87 [Feynman Liang] Compare dense vectors in test
21d01fc [Feynman Liang] Added failing SKLearn test
ac0b254 [Joseph K. Bradley] Added featureImportance to RandomForestClassifier/Regressor. Need to add unit tests
Diffstat (limited to 'mllib/src/test')
6 files changed, 185 insertions, 2 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 32d0b3856b..a66a1e1292 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; @@ -85,6 +86,7 @@ public class JavaRandomForestClassifierSuite implements Serializable { model.toDebugString(); model.trees(); model.treeWeights(); + Vector importances = model.featureImportances(); /* // TODO: Add test once save/load are implemented. SPARK-6725 diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index e306ebadfe..a00ce5e249 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; @@ -85,6 +86,7 @@ public class JavaRandomForestRegressorSuite implements Serializable { model.toDebugString(); model.trees(); model.treeWeights(); + Vector importances = model.featureImportances(); /* // TODO: Add test once save/load are implemented. SPARK-6725 diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index edf848b21a..6ca4b5aa5f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -67,7 +67,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2) ParamsSuite.checkParams(model) } @@ -150,6 +150,35 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte } ///////////////////////////////////////////////////////////////////////////// + // Tests of feature importance + ///////////////////////////////////////////////////////////////////////////// + test("Feature importance with toy data") { + val numClasses = 2 + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setFeatureSubsetStrategy("all") + .setSubsamplingRate(1.0) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) + )) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + + val importances = rf.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + } + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 778abcba22..460849c79f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -124,4 +124,22 @@ private[ml] object TreeTests extends SparkFunSuite { "checkEqual failed since the two tree ensembles were not identical") } } + + /** + * Helper method for constructing a tree for testing. + * Given left, right children, construct a parent node. + * @param split Split for parent node + * @return Parent node with children attached + */ + def buildParentNode(left: Node, right: Node, split: Split): Node = { + val leftImp = left.impurityStats + val rightImp = right.impurityStats + val parentImp = leftImp.copy.add(rightImp) + val leftWeight = leftImp.count / parentImp.count.toDouble + val rightWeight = rightImp.count / parentImp.count.toDouble + val gain = parentImp.calculate() - + (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) + val pred = parentImp.predict + new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index b24ecaa57c..992ce95624 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -26,7 +27,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * Test suite for [[RandomForestRegressor]]. */ @@ -71,6 +71,31 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex regressionTestWithContinuousFeatures(rf) } + test("Feature importance with toy data") { + val rf = new RandomForestRegressor() + .setImpurity("variance") + .setMaxDepth(3) + .setNumTrees(3) + .setFeatureSubsetStrategy("all") + .setSubsamplingRate(1.0) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) + )) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val importances = rf.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala new file mode 100644 index 0000000000..dc852795c7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.tree.impurity.GiniCalculator +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.collection.OpenHashMap + +/** + * Test suite for [[RandomForest]]. + */ +class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { + + import RandomForestSuite.mapToVec + + test("computeFeatureImportance, featureImportances") { + /* Build tree for testing, with this structure: + grandParent + left2 parent + left right + */ + val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) + val left = new LeafNode(0.0, leftImp.calculate(), leftImp) + + val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) + val right = new LeafNode(2.0, rightImp.calculate(), rightImp) + + val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) + val parentImp = parent.impurityStats + + val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) + val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) + + val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) + val grandImp = grandParent.impurityStats + + // Test feature importance computed at different subtrees. + def testNode(node: Node, expected: Map[Int, Double]): Unit = { + val map = new OpenHashMap[Int, Double]() + RandomForest.computeFeatureImportance(node, map) + assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) + } + + // Leaf node + testNode(left, Map.empty[Int, Double]) + + // Internal node with 2 leaf children + val feature0importance = parentImp.calculate() * parentImp.count - + (leftImp.calculate() * leftImp.count + rightImp.calculate() * rightImp.count) + testNode(parent, Map(0 -> feature0importance)) + + // Full tree + val feature1importance = grandImp.calculate() * grandImp.count - + (left2Imp.calculate() * left2Imp.count + parentImp.calculate() * parentImp.count) + testNode(grandParent, Map(0 -> feature0importance, 1 -> feature1importance)) + + // Forest consisting of (full tree) + (internal node with 2 leafs) + val trees = Array(parent, grandParent).map { root => + new DecisionTreeClassificationModel(root, numClasses = 3).asInstanceOf[DecisionTreeModel] + } + val importances: Vector = RandomForest.featureImportances(trees, 2) + val tree2norm = feature0importance + feature1importance + val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0, + (feature1importance / tree2norm) / 2.0) + assert(importances ~== expected relTol 0.01) + } + + test("normalizeMapValues") { + val map = new OpenHashMap[Int, Double]() + map(0) = 1.0 + map(2) = 2.0 + RandomForest.normalizeMapValues(map) + val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) + assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) + } + +} + +private object RandomForestSuite { + + def mapToVec(map: Map[Int, Double]): Vector = { + val size = (map.keys.toSeq :+ 0).max + 1 + val (indices, values) = map.toSeq.sortBy(_._1).unzip + Vectors.sparse(size, indices.toArray, values.toArray) + } +} |