aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala69
1 files changed, 47 insertions, 22 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 84148a8a4a..216377959e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
@@ -29,11 +29,11 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.util.Utils
-
/**
* Test suite for [[GBTRegressor]].
*/
-class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
import GBTRegressorSuite.compareAPIs
@@ -54,7 +54,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
}
- test("Regression with continuous features: SquaredError") {
+ test("Regression with continuous features") {
val categoricalFeatures = Map.empty[Int, Int]
GBTRegressor.supportedLossTypes.foreach { loss =>
testCombinations.foreach {
@@ -110,7 +110,14 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
sc.checkpointDir = None
Utils.deleteRecursively(tempDir)
+ }
+ test("should support all NumericType labels and not support other types") {
+ val gbt = new GBTRegressor().setMaxDepth(1)
+ MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor](
+ gbt, isClassification = false, sqlContext) { (expected, actual) =>
+ TreeTests.checkEqual(expected, actual)
+ }
}
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132
@@ -132,30 +139,48 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
/////////////////////////////////////////////////////////////////////////////
+ // Tests of feature importance
+ /////////////////////////////////////////////////////////////////////////////
+ test("Feature importance with toy data") {
+ val gbt = new GBTRegressor()
+ .setMaxDepth(3)
+ .setMaxIter(5)
+ .setSubsamplingRate(1.0)
+ .setStepSize(0.5)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+ val importances = gbt.fit(df).featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+ assert(importances.toArray.sum === 1.0)
+ assert(importances.toArray.forall(_ >= 0.0))
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
- val treeWeights = Array(0.1, 0.3, 1.1)
- val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights)
- val newModel = GBTRegressionModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = GBTRegressionModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ def checkModelData(
+ model: GBTRegressionModel,
+ model2: GBTRegressionModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
}
+
+ val gbt = new GBTRegressor()
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
}
- */
}
private object GBTRegressorSuite extends SparkFunSuite {