aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala73
1 files changed, 72 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 0767dc17e5..b6783911ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
-import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
@@ -462,3 +462,74 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
/** Get old Gradient Boosting Loss type */
private[ml] def getOldLossType: OldLoss
}
+
+private[ml] object GBTClassifierParams {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: logistic */
+ final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
+}
+
+private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams {
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "logistic"
+ * (default = logistic)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}",
+ (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase))
+
+ setDefault(lossType -> "logistic")
+
+ /** @group getParam */
+ def getLossType: String = $(lossType).toLowerCase
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "logistic" => OldLogLoss
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
+ }
+ }
+}
+
+private[ml] object GBTRegressorParams {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: squared (L2), absolute (L1) */
+ final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+}
+
+private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams {
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "squared" (L2) and "absolute" (L1)
+ * (default = squared)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}",
+ (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase))
+
+ setDefault(lossType -> "squared")
+
+ /** @group getParam */
+ def getLossType: String = $(lossType).toLowerCase
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "squared" => OldSquaredError
+ case "absolute" => OldAbsoluteError
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
+ }
+ }
+}