aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala36
1 files changed, 16 insertions, 20 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 3c11631f98..216377959e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
@@ -32,7 +32,8 @@ import org.apache.spark.util.Utils
/**
* Test suite for [[GBTRegressor]].
*/
-class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
import GBTRegressorSuite.compareAPIs
@@ -164,27 +165,22 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
- val treeWeights = Array(0.1, 0.3, 1.1)
- val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights)
- val newModel = GBTRegressionModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = GBTRegressionModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ def checkModelData(
+ model: GBTRegressionModel,
+ model2: GBTRegressionModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
}
+
+ val gbt = new GBTRegressor()
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
}
- */
}
private object GBTRegressorSuite extends SparkFunSuite {