diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala | 147 |
1 files changed, 124 insertions, 23 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 4fbd957677..b6783911ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} -import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} +import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError} import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** @@ -315,22 +315,8 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { } } -/** - * Parameters for Random Forest algorithms. - * - * Note: Marked as private and DeveloperApi since this may be made public in the future. - */ -private[ml] trait RandomForestParams extends TreeEnsembleParams { - - /** - * Number of trees to train (>= 1). - * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. - * TODO: Change to always do bootstrapping (simpler). SPARK-7130 - * (default = 20) - * @group param - */ - final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", - ParamValidators.gtEq(1)) +/** Used for [[RandomForestParams]] */ +private[ml] trait HasFeatureSubsetStrategy extends Params { /** * The number of features to consider for splits at each tree node. @@ -343,6 +329,8 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { * - "onethird": use 1/3 of the features * - "sqrt": use sqrt(number of features) * - "log2": use log2(number of features) + * - "n": when n is in the range (0, 1.0], use n * number of features. When n + * is in the range (1, number of features), use n features. * (default = "auto") * * These various settings are based on the following references: @@ -360,29 +348,71 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { "The number of features to consider for splits at each tree node." + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}", (value: String) => - RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)) + RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase) + || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex)) - setDefault(numTrees -> 20, featureSubsetStrategy -> "auto") + setDefault(featureSubsetStrategy -> "auto") /** @group setParam */ - def setNumTrees(value: Int): this.type = set(numTrees, value) + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group getParam */ - final def getNumTrees: Int = $(numTrees) + final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase +} + +/** + * Used for [[RandomForestParams]]. + * This is separated out from [[RandomForestParams]] because of an issue with the + * `numTrees` method conflicting with this Param in the Estimator. + */ +private[ml] trait HasNumTrees extends Params { + + /** + * Number of trees to train (>= 1). + * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * TODO: Change to always do bootstrapping (simpler). SPARK-7130 + * (default = 20) + * @group param + */ + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", + ParamValidators.gtEq(1)) + + setDefault(numTrees -> 20) /** @group setParam */ - def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group getParam */ - final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase + final def getNumTrees: Int = $(numTrees) } +/** + * Parameters for Random Forest algorithms. + */ +private[ml] trait RandomForestParams extends TreeEnsembleParams + with HasFeatureSubsetStrategy with HasNumTrees + private[spark] object RandomForestParams { // These options should be lowercase. final val supportedFeatureSubsetStrategies: Array[String] = Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) + + // The regex to capture "(0.0-1.0]", and "n" for integer 0 < n <= (number of features) + final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$" } +private[ml] trait RandomForestClassifierParams + extends RandomForestParams with TreeClassifierParams + +private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams + with HasFeatureSubsetStrategy with TreeClassifierParams + +private[ml] trait RandomForestRegressorParams + extends RandomForestParams with TreeRegressorParams + +private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams + with HasFeatureSubsetStrategy with TreeRegressorParams + /** * Parameters for Gradient-Boosted Tree algorithms. * @@ -432,3 +462,74 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS /** Get old Gradient Boosting Loss type */ private[ml] def getOldLossType: OldLoss } + +private[ml] object GBTClassifierParams { + // The losses below should be lowercase. + /** Accessor for supported loss settings: logistic */ + final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) +} + +private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "logistic" + * (default = logistic) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}", + (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "logistic") + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "logistic" => OldLogLoss + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") + } + } +} + +private[ml] object GBTRegressorParams { + // The losses below should be lowercase. + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ + final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) +} + +private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams { + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "squared" (L2) and "absolute" (L1) + * (default = squared) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}", + (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "squared") + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "squared" => OldSquaredError + case "absolute" => OldAbsoluteError + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") + } + } +} |