aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
author Joseph K. Bradley <joseph@databricks.com> 2015-08-03 12:17:46 -0700
committer Joseph K. Bradley <joseph@databricks.com> 2015-08-03 12:17:46 -0700
commit ff9169a002f1b75231fd25b7d04157a912503038 (patch)
tree ef57aa63ad02760806657e491a78f15f5daa7f66 /mllib/src/test
parent 703e44bff19f4c394f6f9bff1ce9152cdc68c51e (diff)
download spark-ff9169a002f1b75231fd25b7d04157a912503038.tar.gz
spark-ff9169a002f1b75231fd25b7d04157a912503038.tar.bz2
spark-ff9169a002f1b75231fd25b7d04157a912503038.zip
[SPARK-5133] [ML] Added featureImportance to RandomForestClassifier and Regressor
Added featureImportance to RandomForestClassifier and Regressor.
This follows the scikit-learn implementation here: [https://github.com/scikit-learn/scikit-learn/blob/a95203b249c1cf392f86d001ad999e29b2392739/sklearn/tree/_tree.pyx#L3341]

CC: yanboliang Would you mind taking a look? Thanks!

Author: Joseph K. Bradley <joseph@databricks.com>
Author: Feynman Liang <fliang@databricks.com>

Closes #7838 from jkbradley/dt-feature-importance and squashes the following commits:

72a167a [Joseph K. Bradley] fixed unit test
86cea5f [Joseph K. Bradley] Modified RF featuresImportances to return Vector instead of Map
5aa74f0 [Joseph K. Bradley] finally fixed unit test for real
33df5db [Joseph K. Bradley] fix unit test
42a2d3b [Joseph K. Bradley] fix unit test
fe94e72 [Joseph K. Bradley] modified feature importance unit tests
cc693ee [Feynman Liang] Add classifier tests
79a6f87 [Feynman Liang] Compare dense vectors in test
21d01fc [Feynman Liang] Added failing SKLearn test
ac0b254 [Joseph K. Bradley] Added featureImportance to RandomForestClassifier/Regressor. Need to add unit tests
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java 2
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala 31
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala 18
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala 27
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala 107
6 files changed, 185 insertions, 2 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 32d0b3856b..a66a1e1292 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
@@ -85,6 +86,7 @@ public class JavaRandomForestClassifierSuite implements Serializable {
model.toDebugString();
model.trees();
model.treeWeights();
+ Vector importances = model.featureImportances();
/*
// TODO: Add test once save/load are implemented. SPARK-6725
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index e306ebadfe..a00ce5e249 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
@@ -85,6 +86,7 @@ public class JavaRandomForestRegressorSuite implements Serializable {
model.toDebugString();
model.trees();
model.treeWeights();
+ Vector importances = model.featureImportances();
/*
// TODO: Add test once save/load are implemented. SPARK-6725
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index edf848b21a..6ca4b5aa5f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -67,7 +67,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new RandomForestClassifier)
val model = new RandomForestClassificationModel("rfc",
- Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2)
+ Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2)
ParamsSuite.checkParams(model)
}
@@ -150,6 +150,35 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
}
/////////////////////////////////////////////////////////////////////////////
+ // Tests of feature importance
+ /////////////////////////////////////////////////////////////////////////////
+  test("Feature importance with toy data") {
+    // Toy data as (label, features) pairs. In every row the label equals feature 1,
+    // so feature 1 is expected to come out as the most important feature.
+    val rows: Seq[(Double, Array[Double])] = Seq(
+      (0.0, Array(1.0, 0.0, 0.0, 0.0, 1.0)),
+      (1.0, Array(1.0, 1.0, 0.0, 1.0, 0.0)),
+      (1.0, Array(1.0, 1.0, 0.0, 0.0, 0.0)),
+      (0.0, Array(1.0, 0.0, 0.0, 0.0, 0.0)),
+      (1.0, Array(1.0, 1.0, 0.0, 0.0, 0.0))
+    )
+    val points: RDD[LabeledPoint] = sc.parallelize(rows.map { case (label, features) =>
+      LabeledPoint(label, Vectors.dense(features))
+    })
+
+    // Two label classes and no categorical features.
+    val numLabelClasses = 2
+    val df: DataFrame = TreeTests.setMetadata(points, Map.empty[Int, Int], numLabelClasses)
+
+    // Every tree sees all features and all rows; the seed is fixed.
+    val forest = new RandomForestClassifier()
+      .setImpurity("Gini")
+      .setMaxDepth(3)
+      .setNumTrees(3)
+      .setFeatureSubsetStrategy("all")
+      .setSubsamplingRate(1.0)
+      .setSeed(123)
+
+    val model = forest.fit(df)
+    assert(model.featureImportances.argmax === 1)
+  }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index 778abcba22..460849c79f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -124,4 +124,22 @@ private[ml] object TreeTests extends SparkFunSuite {
"checkEqual failed since the two tree ensembles were not identical")
}
}
+
+  /**
+   * Test helper that joins two existing subtrees under a newly created parent node.
+   *
+   * The parent's impurity statistics are obtained by adding the right child's statistics
+   * to a copy of the left child's. The parent's gain is its own impurity minus the
+   * count-weighted average of the two children's impurities, and its prediction is taken
+   * from the combined statistics.
+   *
+   * @param left  Root of the left subtree
+   * @param right  Root of the right subtree
+   * @param split  Split for parent node
+   * @return  Parent node with children attached
+   */
+  def buildParentNode(left: Node, right: Node, split: Split): Node = {
+    val leftStats = left.impurityStats
+    val rightStats = right.impurityStats
+    // Merge into a copy of the left statistics (presumably `add` updates its receiver,
+    // so the copy keeps the left child's own statistics untouched).
+    val mergedStats = leftStats.copy.add(rightStats)
+    val totalCount = mergedStats.count.toDouble
+
+    val parentImpurity = mergedStats.calculate()
+    val childrenImpurity =
+      (leftStats.count / totalCount) * leftStats.calculate() +
+        (rightStats.count / totalCount) * rightStats.calculate()
+
+    new InternalNode(mergedStats.predict, parentImpurity, parentImpurity - childrenImpurity,
+      left, right, split, mergedStats)
+  }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index b24ecaa57c..992ce95624 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -26,7 +27,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* Test suite for [[RandomForestRegressor]].
*/
@@ -71,6 +71,31 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
regressionTestWithContinuousFeatures(rf)
}
+  test("Feature importance with toy data") {
+    // Labels and feature vectors, row by row. The label always equals feature 1,
+    // so feature 1 is expected to come out as the most important feature.
+    val labels = Seq(0.0, 1.0, 1.0, 0.0, 1.0)
+    val features = Seq(
+      Vectors.dense(1.0, 0.0, 0.0, 0.0, 1.0),
+      Vectors.dense(1.0, 1.0, 0.0, 1.0, 0.0),
+      Vectors.dense(1.0, 1.0, 0.0, 0.0, 0.0),
+      Vectors.dense(1.0, 0.0, 0.0, 0.0, 0.0),
+      Vectors.dense(1.0, 1.0, 0.0, 0.0, 0.0)
+    )
+    val points: RDD[LabeledPoint] =
+      sc.parallelize(labels.zip(features).map { case (y, x) => LabeledPoint(y, x) })
+
+    // No categorical features; 0 is passed as the class count
+    // (presumably marking the label as continuous -- confirm in TreeTests.setMetadata).
+    val df: DataFrame = TreeTests.setMetadata(points, Map.empty[Int, Int], 0)
+
+    // Every tree sees all features and all rows; the seed is fixed.
+    val forest = new RandomForestRegressor()
+      .setImpurity("variance")
+      .setMaxDepth(3)
+      .setNumTrees(3)
+      .setFeatureSubsetStrategy("all")
+      .setSubsamplingRate(1.0)
+      .setSeed(123)
+
+    val model = forest.fit(df)
+    assert(model.featureImportances.argmax === 1)
+  }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
new file mode 100644
index 0000000000..dc852795c7
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.tree.impurity.GiniCalculator
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.collection.OpenHashMap
+
+/**
+ * Test suite for [[RandomForest]].
+ *
+ * Covers the feature-importance helpers `computeFeatureImportance`, `featureImportances`
+ * and `normalizeMapValues`.
+ */
+class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  import RandomForestSuite.mapToVec
+
+  test("computeFeatureImportance, featureImportances") {
+    /* Build tree for testing, with this structure:
+          grandParent
+      left2       parent
+                left  right
+     */
+    val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
+    val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
+
+    val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
+    val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
+
+    // `parent` splits on feature 0.
+    val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
+    val parentImp = parent.impurityStats
+
+    val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
+    val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
+
+    // `grandParent` splits on feature 1.
+    val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
+    val grandImp = grandParent.impurityStats
+
+    // Test feature importance computed at different subtrees.
+    def testNode(node: Node, expected: Map[Int, Double]): Unit = {
+      val computed = new OpenHashMap[Int, Double]()
+      RandomForest.computeFeatureImportance(node, computed)
+      assert(mapToVec(computed.toMap) ~== mapToVec(expected) relTol 0.01)
+    }
+
+    // Leaf node: no split, so no feature receives any importance.
+    testNode(left, Map.empty[Int, Double])
+
+    // Internal node with 2 leaf children.
+    // Expected importance: count-scaled impurity of the node minus that of its children.
+    val feature0importance = parentImp.calculate() * parentImp.count -
+      (leftImp.calculate() * leftImp.count + rightImp.calculate() * rightImp.count)
+    testNode(parent, Map(0 -> feature0importance))
+
+    // Full tree
+    val feature1importance = grandImp.calculate() * grandImp.count -
+      (left2Imp.calculate() * left2Imp.count + parentImp.calculate() * parentImp.count)
+    testNode(grandParent, Map(0 -> feature0importance, 1 -> feature1importance))
+
+    // Forest consisting of (internal node with 2 leaves) + (full tree), in that order.
+    // A type ascription (rather than a cast) widens each model to DecisionTreeModel,
+    // so the conversion is checked by the compiler.
+    val trees: Array[DecisionTreeModel] = Array(parent, grandParent).map { root =>
+      (new DecisionTreeClassificationModel(root, numClasses = 3): DecisionTreeModel)
+    }
+    val importances: Vector = RandomForest.featureImportances(trees, 2)
+    // The expected vector averages the per-tree normalized importances:
+    // tree 1 puts all of its weight on feature 0; tree 2 splits it between features 0 and 1.
+    val tree2norm = feature0importance + feature1importance
+    val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
+      (feature1importance / tree2norm) / 2.0)
+    assert(importances ~== expected relTol 0.01)
+  }
+
+  test("normalizeMapValues") {
+    val map = new OpenHashMap[Int, Double]()
+    map(0) = 1.0
+    map(2) = 2.0
+    // The map is normalized in place; afterwards its values should sum to 1.
+    RandomForest.normalizeMapValues(map)
+    val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
+    assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
+  }
+
+}
+
+private object RandomForestSuite {
+
+  /**
+   * Converts a (feature index -> value) map into a sparse vector for comparison in tests.
+   *
+   * The vector size is `max(largest key, 0) + 1`, so an empty map yields a vector of size 1
+   * with no stored entries. Entries are passed to the vector in increasing key order.
+   *
+   * @param map  Values keyed by feature index
+   * @return  Sparse vector holding the map's values at the positions given by its keys
+   */
+  def mapToVec(map: Map[Int, Double]): Vector = {
+    // Starting the fold at 0 gives the same result as appending 0 to the keys before `max`.
+    val largestIndex = map.keysIterator.foldLeft(0)((best, key) => math.max(best, key))
+    val ordered = map.toArray.sortBy { case (index, _) => index }
+    Vectors.sparse(largestIndex + 1, ordered.map(_._1), ordered.map(_._2))
+  }
+}