From f9d578eaa107d8e8503c1563a2b3990c85104298 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 13 Apr 2016 11:31:10 -0700 Subject: [SPARK-13783][ML] Model export/import for spark.ml: GBTs ## What changes were proposed in this pull request? * Added save/load for ```GBTClassifier/GBTClassificationModel/GBTRegressor/GBTRegressionModel```. * Meanwhile, I modified ```EnsembleModelReadWrite.saveImpl/loadImpl``` to support save/load ```treeWeights```. ## How was this patch tested? Adds standard unit tests for GBT save/load. cc jkbradley GayathriMurali Author: Yanbo Liang Closes #12230 from yanboliang/spark-13783. --- .../ml/classification/GBTClassifierSuite.scala | 37 ++++++++++------------ .../spark/ml/regression/GBTRegressorSuite.scala | 36 ++++++++++----------- 2 files changed, 33 insertions(+), 40 deletions(-) (limited to 'mllib/src/test') diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 76d8c9372e..7e6aec6b1b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -34,7 +34,8 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { +class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { import GBTClassifierSuite.compareAPIs @@ -156,27 +157,23 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val treeWeights = Array(0.1, 0.3, 1.1) - val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights) - val newModel = GBTClassificationModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = GBTClassificationModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + def checkModelData( + model: GBTClassificationModel, + model2: GBTClassificationModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val gbt = new GBTClassifier() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "logistic") + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) } - */ } private object GBTClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 3c11631f98..216377959e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -32,7 +32,8 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTRegressor]]. */ -class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { +class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { import GBTRegressorSuite.compareAPIs @@ -164,27 +165,22 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val treeWeights = Array(0.1, 0.3, 1.1) - val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights) - val newModel = GBTRegressionModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = GBTRegressionModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + def checkModelData( + model: GBTRegressionModel, + model2: GBTRegressionModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val gbt = new GBTRegressor() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared") + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) } - */ } private object GBTRegressorSuite extends SparkFunSuite { -- cgit v1.2.3