diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala | 37 |
1 files changed, 17 insertions, 20 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 76d8c9372e..7e6aec6b1b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -34,7 +34,8 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { +class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { import GBTClassifierSuite.compareAPIs @@ -156,27 +157,23 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val treeWeights = Array(0.1, 0.3, 1.1) - val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights) - val newModel = GBTClassificationModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = GBTClassificationModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + def checkModelData( + model: GBTClassificationModel, + model2: GBTClassificationModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val gbt = new GBTClassifier() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "logistic") + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) } - */ } private object GBTClassifierSuite extends SparkFunSuite { |