diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-04-13 11:31:10 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-13 11:31:10 -0700 |
commit | f9d578eaa107d8e8503c1563a2b3990c85104298 (patch) | |
tree | c4410c020f61d9c48780eb0f108be250d254f42f /mllib/src/main | |
parent | 7d2ed8cc030f3d84fea47fded072c320c3d87ca7 (diff) | |
download | spark-f9d578eaa107d8e8503c1563a2b3990c85104298.tar.gz spark-f9d578eaa107d8e8503c1563a2b3990c85104298.tar.bz2 spark-f9d578eaa107d8e8503c1563a2b3990c85104298.zip |
[SPARK-13783][ML] Model export/import for spark.ml: GBTs
## What changes were proposed in this pull request?
* Added save/load for ```GBTClassifier/GBTClassificationModel/GBTRegressor/GBTRegressionModel```.
* Meanwhile, I modified ```EnsembleModelReadWrite.saveImpl/loadImpl``` to support save/load ```treeWeights```.
## How was this patch tested?
Adds standard unit tests for GBT save/load.
cc jkbradley GayathriMurali
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #12230 from yanboliang/spark-13783.
Diffstat (limited to 'mllib/src/main')
6 files changed, 229 insertions, 97 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 46e8b89d01..39a698af15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -18,19 +18,21 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset} @@ -58,7 +60,7 @@ import org.apache.spark.sql.functions._ final class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] - with GBTParams with TreeClassifierParams with Logging { + with GBTClassifierParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtc")) @@ -115,40 +117,12 @@ final class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) - // Parameters for GBTClassifier: - - /** - * Loss function which GBT tries to minimize. (case-insensitive) - * Supported: "logistic" - * (default = logistic) - * @group param - */ - @Since("1.4.0") - val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + - " tries to minimize (case-insensitive). Supported options:" + - s" ${GBTClassifier.supportedLossTypes.mkString(", ")}", - (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase)) - - setDefault(lossType -> "logistic") + // Parameters from GBTClassifierParams: /** @group setParam */ @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) - /** @group getParam */ - @Since("1.4.0") - def getLossType: String = $(lossType).toLowerCase - - /** (private[ml]) Convert new loss to old loss. */ - override private[ml] def getOldLossType: OldLoss = { - getLossType match { - case "logistic" => OldLogLoss - case _ => - // Should never happen because of check in setter method. - throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") - } - } - override protected def train(dataset: Dataset[_]): GBTClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -175,11 +149,14 @@ final class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") @Experimental -object GBTClassifier { - // The losses below should be lowercase. +object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { + /** Accessor for supported loss settings: logistic */ @Since("1.4.0") - final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes + + @Since("2.0.0") + override def load(path: String): GBTClassifier = super.load(path) } /** @@ -199,7 +176,8 @@ final class GBTClassificationModel private[ml]( private val _treeWeights: Array[Double], @Since("1.6.0") override val numFeatures: Int) extends PredictionModel[Vector, GBTClassificationModel] - with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable { + with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + @@ -267,12 +245,62 @@ final class GBTClassificationModel private[ml]( private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) } + + @Since("2.0.0") + override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this) } -private[ml] object GBTClassificationModel { +@Since("2.0.0") +object GBTClassificationModel extends MLReadable[GBTClassificationModel] { + + @Since("2.0.0") + override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader + + @Since("2.0.0") + override def load(path: String): GBTClassificationModel = super.load(path) + + private[GBTClassificationModel] + class GBTClassificationModelWriter(instance: GBTClassificationModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class GBTClassificationModelReader extends MLReader[GBTClassificationModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GBTClassificationModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): GBTClassificationModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldGBTModel, parent: GBTClassifier, categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 9d80b8eb68..dfa711b243 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -294,7 +294,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica override def load(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats - val (metadata: Metadata, treesData: Array[(Metadata, Node)]) = + val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 0b52fe2d13..741724d7a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -18,19 +18,20 @@ package org.apache.spark.ml.regression import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{GBTParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, - SquaredError => OldSquaredError} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset} @@ -58,7 +59,7 @@ import org.apache.spark.sql.functions._ @Experimental final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] - with GBTParams with TreeRegressorParams with Logging { + with GBTRegressorParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtr")) @@ -112,41 +113,12 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) - // Parameters for GBTRegressor: - - /** - * Loss function which GBT tries to minimize. (case-insensitive) - * Supported: "squared" (L2) and "absolute" (L1) - * (default = squared) - * @group param - */ - @Since("1.4.0") - val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + - " tries to minimize (case-insensitive). Supported options:" + - s" ${GBTRegressor.supportedLossTypes.mkString(", ")}", - (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase)) - - setDefault(lossType -> "squared") + // Parameters from GBTRegressorParams: /** @group setParam */ @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) - /** @group getParam */ - @Since("1.4.0") - def getLossType: String = $(lossType).toLowerCase - - /** (private[ml]) Convert new loss to old loss. */ - override private[ml] def getOldLossType: OldLoss = { - getLossType match { - case "squared" => OldSquaredError - case "absolute" => OldAbsoluteError - case _ => - // Should never happen because of check in setter method. - throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") - } - } - override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -164,11 +136,14 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri @Since("1.4.0") @Experimental -object GBTRegressor { - // The losses below should be lowercase. +object GBTRegressor extends DefaultParamsReadable[GBTRegressor] { + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ @Since("1.4.0") - final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes + + @Since("2.0.0") + override def load(path: String): GBTRegressor = super.load(path) } /** @@ -188,7 +163,8 @@ final class GBTRegressionModel private[ml]( private val _treeWeights: Array[Double], override val numFeatures: Int) extends PredictionModel[Vector, GBTRegressionModel] - with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable { + with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + @@ -255,12 +231,64 @@ final class GBTRegressionModel private[ml]( private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) } + + @Since("2.0.0") + override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this) } -private[ml] object GBTRegressionModel { +@Since("2.0.0") +object GBTRegressionModel extends MLReadable[GBTRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader + + @Since("2.0.0") + override def load(path: String): GBTRegressionModel = super.load(path) + + private[GBTRegressionModel] + class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GBTRegressionModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): GBTRegressionModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + + require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldGBTModel, parent: GBTRegressor, categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index bee13c2ebf..4c4ff278d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -249,7 +249,7 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode override def load(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats - val (metadata: Metadata, treesData: Array[(Metadata, Node)]) = + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index c4ab673d9a..f38e1ec7c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -396,12 +396,14 @@ private[ml] object EnsembleModelReadWrite { sql: SQLContext, extraMetadata: JObject): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata)) - val treesMetadataJson: Array[(Int, String)] = instance.trees.zipWithIndex.map { + val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map { case (tree, treeID) => - treeID -> DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext) + (treeID, + DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext), + instance.treeWeights(treeID)) } val treesMetadataPath = new Path(path, "treesMetadata").toString - sql.createDataFrame(treesMetadataJson).toDF("treeID", "metadata") + sql.createDataFrame(treesMetadataWeights).toDF("treeID", "metadata", "weights") .write.parquet(treesMetadataPath) val dataPath = new Path(path, "data").toString val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap { @@ -424,7 +426,7 @@ private[ml] object EnsembleModelReadWrite { path: String, sql: SQLContext, className: String, - treeClassName: String): (Metadata, Array[(Metadata, Node)]) = { + treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sql.implicits._ implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) @@ -436,12 +438,15 @@ private[ml] object EnsembleModelReadWrite { } val treesMetadataPath = new Path(path, "treesMetadata").toString - val treesMetadataRDD: RDD[(Int, Metadata)] = sql.read.parquet(treesMetadataPath) - .select("treeID", "metadata").as[(Int, String)].rdd.map { - case (treeID: Int, json: String) => - treeID -> DefaultParamsReader.parseMetadata(json, treeClassName) + val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath) + .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map { + case (treeID: Int, json: String, weights: Double) => + treeID -> (DefaultParamsReader.parseMetadata(json, treeClassName), weights) } - val treesMetadata: Array[Metadata] = treesMetadataRDD.sortByKey().values.collect() + + val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect() + val treesMetadata = treesMetadataWeights.map(_._1) + val treesWeights = treesMetadataWeights.map(_._2) val dataPath = new Path(path, "data").toString val nodeData: Dataset[EnsembleNodeData] = @@ -452,7 +457,7 @@ private[ml] object EnsembleModelReadWrite { treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() - (metadata, treesMetadata.zip(rootNodes)) + (metadata, treesMetadata.zip(rootNodes), treesWeights) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 0767dc17e5..b6783911ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} -import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} +import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError} import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** @@ -462,3 +462,74 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS /** Get old Gradient Boosting Loss type */ private[ml] def getOldLossType: OldLoss } + +private[ml] object GBTClassifierParams { + // The losses below should be lowercase. + /** Accessor for supported loss settings: logistic */ + final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) +} + +private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "logistic" + * (default = logistic) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}", + (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "logistic") + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "logistic" => OldLogLoss + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") + } + } +} + +private[ml] object GBTRegressorParams { + // The losses below should be lowercase. + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ + final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) +} + +private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams { + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "squared" (L2) and "absolute" (L1) + * (default = squared) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}", + (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "squared") + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "squared" => OldSquaredError + case "absolute" => OldAbsoluteError + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") + } + } +} |