about | summary | refs | log | tree | commit | diff
path: root/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala | 114
1 files changed, 71 insertions, 43 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 0b52fe2d13..741724d7a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -18,19 +18,20 @@
package org.apache.spark.ml.regression
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.tree.{GBTParams, TreeEnsembleModel, TreeRegressorParams}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
- SquaredError => OldSquaredError}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -58,7 +59,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
- with GBTParams with TreeRegressorParams with Logging {
+ with GBTRegressorParams with DefaultParamsWritable with Logging {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("gbtr"))
@@ -112,41 +113,12 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
@Since("1.4.0")
override def setStepSize(value: Double): this.type = super.setStepSize(value)
- // Parameters for GBTRegressor:
-
- /**
- * Loss function which GBT tries to minimize. (case-insensitive)
- * Supported: "squared" (L2) and "absolute" (L1)
- * (default = squared)
- * @group param
- */
- @Since("1.4.0")
- val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
- " tries to minimize (case-insensitive). Supported options:" +
- s" ${GBTRegressor.supportedLossTypes.mkString(", ")}",
- (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase))
-
- setDefault(lossType -> "squared")
+ // Parameters from GBTRegressorParams:
/** @group setParam */
@Since("1.4.0")
def setLossType(value: String): this.type = set(lossType, value)
- /** @group getParam */
- @Since("1.4.0")
- def getLossType: String = $(lossType).toLowerCase
-
- /** (private[ml]) Convert new loss to old loss. */
- override private[ml] def getOldLossType: OldLoss = {
- getLossType match {
- case "squared" => OldSquaredError
- case "absolute" => OldAbsoluteError
- case _ =>
- // Should never happen because of check in setter method.
- throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
- }
- }
-
override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -164,11 +136,14 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
@Since("1.4.0")
@Experimental
-object GBTRegressor {
- // The losses below should be lowercase.
+object GBTRegressor extends DefaultParamsReadable[GBTRegressor] {
+
/** Accessor for supported loss settings: squared (L2), absolute (L1) */
@Since("1.4.0")
- final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+ final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes
+
+ @Since("2.0.0")
+ override def load(path: String): GBTRegressor = super.load(path)
}
/**
@@ -188,7 +163,8 @@ final class GBTRegressionModel private[ml](
private val _treeWeights: Array[Double],
override val numFeatures: Int)
extends PredictionModel[Vector, GBTRegressionModel]
- with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable {
+ with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
@@ -255,12 +231,64 @@ final class GBTRegressionModel private[ml](
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this)
}
-private[ml] object GBTRegressionModel {
+@Since("2.0.0")
+object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
+
+ @Since("2.0.0") // MLReadable entry point: hands back a new GBTRegressionModelReader on each call
+ override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader
+
+ @Since("2.0.0") // delegates unchanged to the inherited load(path)
+ override def load(path: String): GBTRegressionModel = super.load(path)
+
+ private[GBTRegressionModel] // visibility limited to the GBTRegressionModel scope
+ class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter { // MLWriter for the given model instance
+
+ override protected def saveImpl(path: String): Unit = { // saves `instance` under `path`
+ val extraMetadata: JObject = Map( // model-level fields that GBTRegressionModelReader.load reads back
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) // shared ensemble helper does the actual writing
+ }
+ }
+
+ private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] { // rebuilds a GBTRegressionModel from a saved path
+
+ /** Class names passed to EnsembleModelReadWrite.loadImpl, to be checked against the saved metadata when loading a model */
+ private val className = classOf[GBTRegressionModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): GBTRegressionModel = {
+ implicit val format = DefaultFormats // implicit json4s formats used by the `extract` calls below
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = // ensemble metadata, per-tree (metadata, root) pairs, tree weights
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] // field stored by GBTRegressionModelWriter
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int] // field stored by GBTRegressionModelWriter
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map { // one tree model per loaded (metadata, root) pair
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata) // apply the tree's own saved params
+ tree
+ }
+
+ require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" + // fail if the loaded tree count disagrees with metadata
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata) // apply the ensemble-level saved params
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldGBTModel,
parent: GBTRegressor,
categoricalFeatures: Map[Int, Int],