Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala | 147
1 file changed, 124 insertions(+), 23 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 4fbd957677..b6783911ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
-import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
@@ -315,22 +315,8 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
}
}
-/**
- * Parameters for Random Forest algorithms.
- *
- * Note: Marked as private and DeveloperApi since this may be made public in the future.
- */
-private[ml] trait RandomForestParams extends TreeEnsembleParams {
-
- /**
- * Number of trees to train (>= 1).
- * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
- * TODO: Change to always do bootstrapping (simpler). SPARK-7130
- * (default = 20)
- * @group param
- */
- final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
- ParamValidators.gtEq(1))
+/** Used for [[RandomForestParams]] */
+private[ml] trait HasFeatureSubsetStrategy extends Params {
/**
* The number of features to consider for splits at each tree node.
@@ -343,6 +329,8 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
* - "onethird": use 1/3 of the features
* - "sqrt": use sqrt(number of features)
* - "log2": use log2(number of features)
+ * - "n": when n is in the range (0, 1.0], use n * number of features. When n
+ * is in the range (1, number of features), use n features.
* (default = "auto")
*
* These various settings are based on the following references:
@@ -360,29 +348,71 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
"The number of features to consider for splits at each tree node." +
s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
(value: String) =>
- RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
+ RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)
+ || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex))
- setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
+ setDefault(featureSubsetStrategy -> "auto")
/** @group setParam */
- def setNumTrees(value: Int): this.type = set(numTrees, value)
+ def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
/** @group getParam */
- final def getNumTrees: Int = $(numTrees)
+ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
+}
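
For context, a usage sketch (not part of the patch) of the new fractional strategy through the public estimator API, assuming this change is applied; RandomForestClassifier and its setters are the existing spark.ml API:

    import org.apache.spark.ml.classification.RandomForestClassifier

    // "0.5" passes the new regex check: use 0.5 * (number of features) at each split.
    val rf = new RandomForestClassifier()
      .setNumTrees(50)
      .setFeatureSubsetStrategy("0.5")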
+
+/**
+ * Used for [[RandomForestParams]].
+ * This is separated out from [[RandomForestParams]] because of an issue with the
+ * `numTrees` method conflicting with this Param in the Estimator.
+ */
+private[ml] trait HasNumTrees extends Params {
+
+ /**
+ * Number of trees to train (>= 1).
+ * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+ * TODO: Change to always do bootstrapping (simpler). SPARK-7130
+ * (default = 20)
+ * @group param
+ */
+ final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+ ParamValidators.gtEq(1))
+
+ setDefault(numTrees -> 20)
/** @group setParam */
- def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
+ def setNumTrees(value: Int): this.type = set(numTrees, value)
/** @group getParam */
- final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
+ final def getNumTrees: Int = $(numTrees)
}
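
The conflict mentioned in the scaladoc above can be illustrated with a hypothetical sketch (names invented here, not from the patch): Scala does not allow a def to override a val inherited from a mixin, so a class that already defines a numTrees method cannot also mix in a trait declaring the numTrees Param as a val.

    // Hypothetical illustration of the name clash; this does not compile:
    trait HasNumTreesLike { val numTrees: Int = 20 }

    class FittedModel extends HasNumTreesLike {
      def numTrees: Int = 3  // error: method numTrees cannot override val numTrees
    }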
+/**
+ * Parameters for Random Forest algorithms.
+ */
+private[ml] trait RandomForestParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with HasNumTrees
+
private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
+
+  // The regex to capture fractions in the range (0, 1.0] and integers "n" with 0 < n <= (number of features)
+ final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
}
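
As a quick standalone check (not part of the patch), the regex accepts positive integers and fractions in (0, 1.0], while zero is rejected and the named strategies such as "log2" are handled by the contains check instead:

    val regex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
    Seq("3", "0.5", ".5", "1.0", "0.0", "0", "log2").foreach { s =>
      println(s"$s -> ${s.matches(regex)}")
    }
    // 3, 0.5, .5, 1.0 -> true; 0.0, 0, log2 -> false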
+private[ml] trait RandomForestClassifierParams
+ extends RandomForestParams with TreeClassifierParams
+
+private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with TreeClassifierParams
+
+private[ml] trait RandomForestRegressorParams
+ extends RandomForestParams with TreeRegressorParams
+
+private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with TreeRegressorParams
+
/**
* Parameters for Gradient-Boosted Tree algorithms.
*
@@ -432,3 +462,74 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
/** Get old Gradient Boosting Loss type */
private[ml] def getOldLossType: OldLoss
}
+
+private[ml] object GBTClassifierParams {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: logistic */
+ final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
+}
+
+private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams {
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "logistic"
+ * (default = logistic)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}",
+ (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase))
+
+ setDefault(lossType -> "logistic")
+
+ /** @group getParam */
+ def getLossType: String = $(lossType).toLowerCase
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "logistic" => OldLogLoss
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
+ }
+ }
+}
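
A brief usage sketch (assuming this patch is applied; GBTClassifier is the existing spark.ml estimator): the value is case-insensitive, and the single supported option maps to OldLogLoss internally.

    import org.apache.spark.ml.classification.GBTClassifier

    val gbt = new GBTClassifier()
      .setLossType("LOGISTIC")  // validated and read back lowercase as "logistic"
      .setMaxIter(10)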
+
+private[ml] object GBTRegressorParams {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: squared (L2), absolute (L1) */
+ final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+}
+
+private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams {
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "squared" (L2) and "absolute" (L1)
+ * (default = squared)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}",
+ (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase))
+
+ setDefault(lossType -> "squared")
+
+ /** @group getParam */
+ def getLossType: String = $(lossType).toLowerCase
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "squared" => OldSquaredError
+ case "absolute" => OldAbsoluteError
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
+ }
+ }
+}
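
And the regression counterpart (a sketch under the same assumptions; GBTRegressor is the existing spark.ml estimator): "absolute" selects the L1 loss, which getOldLossType maps to OldAbsoluteError.

    import org.apache.spark.ml.regression.GBTRegressor

    val gbt = new GBTRegressor()
      .setLossType("absolute")  // L1; "squared" (the default) selects L2 / OldSquaredError
      .setMaxIter(10)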