diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-08-03 12:17:46 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-03 12:17:46 -0700 |
commit | ff9169a002f1b75231fd25b7d04157a912503038 (patch) | |
tree | ef57aa63ad02760806657e491a78f15f5daa7f66 /mllib/src/test/java/org | |
parent | 703e44bff19f4c394f6f9bff1ce9152cdc68c51e (diff) | |
download | spark-ff9169a002f1b75231fd25b7d04157a912503038.tar.gz spark-ff9169a002f1b75231fd25b7d04157a912503038.tar.bz2 spark-ff9169a002f1b75231fd25b7d04157a912503038.zip |
[SPARK-5133] [ML] Added featureImportance to RandomForestClassifier and Regressor
Added featureImportance to RandomForestClassifier and Regressor.
This follows the scikit-learn implementation here: [https://github.com/scikit-learn/scikit-learn/blob/a95203b249c1cf392f86d001ad999e29b2392739/sklearn/tree/_tree.pyx#L3341]
CC: yanboliang Would you mind taking a look? Thanks!
Author: Joseph K. Bradley <joseph@databricks.com>
Author: Feynman Liang <fliang@databricks.com>
Closes #7838 from jkbradley/dt-feature-importance and squashes the following commits:
72a167a [Joseph K. Bradley] fixed unit test
86cea5f [Joseph K. Bradley] Modified RF featuresImportances to return Vector instead of Map
5aa74f0 [Joseph K. Bradley] finally fixed unit test for real
33df5db [Joseph K. Bradley] fix unit test
42a2d3b [Joseph K. Bradley] fix unit test
fe94e72 [Joseph K. Bradley] modified feature importance unit tests
cc693ee [Feynman Liang] Add classifier tests
79a6f87 [Feynman Liang] Compare dense vectors in test
21d01fc [Feynman Liang] Added failing SKLearn test
ac0b254 [Joseph K. Bradley] Added featureImportance to RandomForestClassifier/Regressor. Need to add unit tests
Diffstat (limited to 'mllib/src/test/java/org')
2 files changed, 4 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 32d0b3856b..a66a1e1292 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; @@ -85,6 +86,7 @@ public class JavaRandomForestClassifierSuite implements Serializable { model.toDebugString(); model.trees(); model.treeWeights(); + Vector importances = model.featureImportances(); /* // TODO: Add test once save/load are implemented. SPARK-6725 diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index e306ebadfe..a00ce5e249 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; @@ -85,6 +86,7 @@ public class JavaRandomForestRegressorSuite implements Serializable { model.toDebugString(); model.trees(); model.treeWeights(); + Vector importances = model.featureImportances(); /* // TODO: Add test once save/load are implemented. SPARK-6725 |