aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala 71
1 files changed, 50 insertions, 21 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index f3680ed044..7e6aec6b1b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -31,11 +31,11 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.util.Utils
-
/**
* Test suite for [[GBTClassifier]].
*/
-class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
import GBTClassifierSuite.compareAPIs
@@ -102,6 +102,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
Utils.deleteRecursively(tempDir)
}
+ test("should support all NumericType labels and not support other types") {
+   // Estimator under test, restricted to depth-1 trees.
+   val shallowGbt = new GBTClassifier().setMaxDepth(1)
+   // Comparison callback handed to the numeric-type checker; delegates to TreeTests.checkEqual.
+   val modelsMatch = (expected: GBTClassificationModel, actual: GBTClassificationModel) =>
+     TreeTests.checkEqual(expected, actual)
+   MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier](
+     shallowGbt, isClassification = true, sqlContext)(modelsMatch)
+ }
+
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132
/*
test("runWithValidation stops early and performs better on a validation dataset") {
@@ -121,30 +129,51 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
/////////////////////////////////////////////////////////////////////////////
+ // Tests of feature importance
+ /////////////////////////////////////////////////////////////////////////////
+ test("Feature importance with toy data") {
+   // Fits a GBT classifier with a fixed seed and no subsampling on the toy feature-importance
+   // dataset, then checks that the importance vector singles out feature 1, is non-negative,
+   // and sums to 1.
+   val numClasses = 2
+   val gbt = new GBTClassifier()
+     .setImpurity("Gini")
+     .setMaxDepth(3)
+     .setMaxIter(5)
+     .setSubsamplingRate(1.0)
+     .setStepSize(0.5)
+     .setSeed(123)
+
+   // In this data, feature 1 is very important.
+   val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+   val categoricalFeatures = Map.empty[Int, Int]
+   val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+
+   val importances = gbt.fit(df).featureImportances
+   val mostImportantFeature = importances.argmax
+   assert(mostImportantFeature === 1)
+   // The sum is a floating-point accumulation, so compare against 1.0 with a tolerance
+   // instead of requiring bit-exact equality.
+   val importanceSum = importances.toArray.sum
+   assert(math.abs(importanceSum - 1.0) < 1e-9,
+     s"feature importances sum to $importanceSum, expected 1.0")
+   assert(importances.toArray.forall(_ >= 0.0))
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
- val treeWeights = Array(0.1, 0.3, 1.1)
- val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights)
- val newModel = GBTClassificationModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = GBTClassificationModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ // Comparison passed to testEstimatorAndModelReadWrite below: checks the two models with
+ // TreeTests.checkEqual and asserts that their numFeatures agree.
+ def checkModelData(
+ model: GBTClassificationModel,
+ model2: GBTClassificationModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
}
+
+ val gbt = new GBTClassifier()
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ // Shared tree param settings from TreeTests plus a "lossType" -> "logistic" entry.
+ val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "logistic")
+
+ // Attach metadata using an empty categorical-feature map and numClasses = 2.
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
}
- */
}
private object GBTClassifierSuite extends SparkFunSuite {