aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2016-03-09 14:44:51 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-03-09 14:44:51 -0800
commite1772d3f19bed7e69a80de7900ed22d3eeb05300 (patch)
tree9db2d2a2b3ac0786141cc51790dc4de0f8e307c5 /mllib/src/test/scala/org/apache
parentc6aa356cd831ea2d159568b699bd5b791f3d8f25 (diff)
downloadspark-e1772d3f19bed7e69a80de7900ed22d3eeb05300.tar.gz
spark-e1772d3f19bed7e69a80de7900ed22d3eeb05300.tar.bz2
spark-e1772d3f19bed7e69a80de7900ed22d3eeb05300.zip
[SPARK-11861][ML] Add feature importances for decision trees
This patch adds an API entry point for single decision tree feature importances. Author: sethah <seth.hendrickson16@gmail.com> Closes #9912 from sethah/SPARK-11861.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala21
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala10
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala13
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala20
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala13
5 files changed, 60 insertions, 17 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 9169bcd390..6d68364499 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -313,6 +313,27 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
}
}
+ test("Feature importance with toy data") {
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("gini")
+ .setMaxDepth(3)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val numFeatures = data.first().features.size
+ val categoricalFeatures = (0 to numFeatures).map(i => (i, 2)).toMap
+ val df = TreeTests.setMetadata(data, categoricalFeatures, 2)
+
+ val model = dt.fit(df)
+
+ val importances = model.featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+ assert(importances.toArray.sum === 1.0)
+ assert(importances.toArray.forall(_ >= 0.0))
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index deb8ec771c..6b810ab9ea 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -167,19 +167,15 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
.setSeed(123)
// In this data, feature 1 is very important.
- val data: RDD[LabeledPoint] = sc.parallelize(Seq(
- new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
- new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
- new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
- new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
- new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
- ))
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
val categoricalFeatures = Map.empty[Int, Int]
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
val importances = rf.fit(df).featureImportances
val mostImportantFeature = importances.argmax
assert(mostImportantFeature === 1)
+ assert(importances.toArray.sum === 1.0)
+ assert(importances.toArray.forall(_ >= 0.0))
}
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index a808177cb9..5561f6f0ef 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -19,10 +19,12 @@ package org.apache.spark.ml.impl
import scala.collection.JavaConverters._
+import org.apache.spark.SparkContext
import org.apache.spark.SparkFunSuite
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
@@ -141,4 +143,15 @@ private[ml] object TreeTests extends SparkFunSuite {
val pred = parentImp.predict
new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp)
}
+
+ /**
+ * Create some toy data for testing feature importances.
+ */
+ def featureImportanceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq(
+ new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
+ new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
+ ))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 13165f6701..56b335a33a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -96,6 +96,26 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
}
}
+ test("Feature importance with toy data") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(3)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+ val model = dt.fit(df)
+
+ val importances = model.featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+ assert(importances.toArray.sum === 1.0)
+ assert(importances.toArray.forall(_ >= 0.0))
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 7e751e4b55..efb117f8f9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.util.MLTestingUtils
-import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -82,23 +81,17 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
.setSeed(123)
// In this data, feature 1 is very important.
- val data: RDD[LabeledPoint] = sc.parallelize(Seq(
- new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
- new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
- new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
- new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
- new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
- ))
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
val categoricalFeatures = Map.empty[Int, Int]
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
val model = rf.fit(df)
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
val importances = model.featureImportances
val mostImportantFeature = importances.argmax
assert(mostImportantFeature === 1)
+ assert(importances.toArray.sum === 1.0)
+ assert(importances.toArray.forall(_ >= 0.0))
}
/////////////////////////////////////////////////////////////////////////////