Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala | 37
1 file changed, 17 insertions, 20 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 76d8c9372e..7e6aec6b1b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -34,7 +34,8 @@ import org.apache.spark.util.Utils
/**
* Test suite for [[GBTClassifier]].
*/
-class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
import GBTClassifierSuite.compareAPIs
@@ -156,27 +157,23 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
- val treeWeights = Array(0.1, 0.3, 1.1)
- val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights)
- val newModel = GBTClassificationModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = GBTClassificationModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ def checkModelData(
+ model: GBTClassificationModel,
+ model2: GBTClassificationModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
}
+
+ val gbt = new GBTClassifier()
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "logistic")
+
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
}
- */
}
private object GBTClassifierSuite extends SparkFunSuite {