Diffstat (limited to 'mllib/src/main/scala/org/apache')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala |  5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala                |  8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala           | 11
3 files changed, 21 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index df8eb5d1f9..c7cde1563f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -183,11 +183,16 @@ private[spark] object DecisionTreeMetadata extends Logging {
}
case _ => featureSubsetStrategy
}
+
+ val isIntRegex = "^([1-9]\\d*)$".r
+ val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r
val numFeaturesPerNode: Int = _featureSubsetStrategy match {
case "all" => numFeatures
case "sqrt" => math.sqrt(numFeatures).ceil.toInt
case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
case "onethird" => (numFeatures / 3.0).ceil.toInt
+ case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt
+ case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt
}
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
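
For illustration only (not part of the patch): a minimal standalone sketch of how the two new regex cases above resolve a numeric strategy string to a per-node feature count. The object name and the sample numFeatures value are made up for the example.

object FeatureSubsetResolutionSketch {
  // Same patterns as the patch: a positive integer, or a fraction in (0, 1.0].
  val isIntRegex = "^([1-9]\\d*)$".r
  val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r

  // Mirrors only the two new numeric cases of the numFeaturesPerNode match.
  def resolve(strategy: String, numFeatures: Int): Int = strategy match {
    case isIntRegex(number) =>
      if (BigInt(number) > numFeatures) numFeatures else number.toInt
    case isFractionRegex(fraction) =>
      (fraction.toDouble * numFeatures).ceil.toInt
  }

  def main(args: Array[String]): Unit = {
    println(resolve("30", 100))  // 30  -- integer n within [1, numFeatures]
    println(resolve("500", 100)) // 100 -- integers above numFeatures are capped
    println(resolve("0.3", 100)) // 30  -- fraction of the features, rounded up
    println(resolve("1.0", 100)) // 100
  }
}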
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 78e6d3bfac..0767dc17e5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -329,6 +329,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
* - "onethird": use 1/3 of the features
* - "sqrt": use sqrt(number of features)
* - "log2": use log2(number of features)
+ * - "n": when n is in the range (0, 1.0], use n * number of features. When n
+ * is in the range (1, number of features), use n features.
* (default = "auto")
*
* These various settings are based on the following references:
@@ -346,7 +348,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
"The number of features to consider for splits at each tree node." +
s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
(value: String) =>
- RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
+ RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)
+ || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex))
setDefault(featureSubsetStrategy -> "auto")
@@ -393,6 +396,9 @@ private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
+
+ // The regex to capture a fraction in "(0.0-1.0]" or an integer "n" for 0 < n <= (number of features)
+ final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
}
private[ml] trait RandomForestClassifierParams
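
As an aside (not part of the patch): a small sketch of which strings the new validation regex accepts, assuming it is used exactly as declared above. Named strategies such as "auto" are still accepted through the separate contains check, not through this regex.

object StrategyRegexSketch {
  val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"

  def main(args: Array[String]): Unit = {
    // Positive integers and fractions in (0, 1.0] match.
    Seq("1", "7", "0.5", ".25", "1.0")
      .foreach(s => println(s"$s -> ${s.matches(supportedFeatureSubsetStrategiesRegex)}")) // all true
    // Zero, negatives, values above 1.0, and named strategies do not match.
    Seq("0", "-3", "0.0", "1.5", "auto")
      .foreach(s => println(s"$s -> ${s.matches(supportedFeatureSubsetStrategiesRegex)}")) // all false
  }
}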
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 1841fa4a95..26755849ad 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -55,10 +55,15 @@ import org.apache.spark.util.Utils
* @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
* Supported values: "auto", "all", "sqrt", "log2", "onethird".
+ * Supported numerical values: "(0.0-1.0]", "[1-n]".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
+ * If a real value "n" in the range (0, 1.0] is set,
+ * use n * number of features.
+ * If an integer value "n" in the range (1, num features) is set,
+ * use n features.
* @param seed Random seed for bootstrapping and choosing feature subsets.
*/
private class RandomForest (
@@ -70,9 +75,11 @@ private class RandomForest (
strategy.assertValid()
require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
- require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
+ require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
+ || featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex),
s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
- s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.")
+ s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," +
+ s" (0.0-1.0], [1-n].")
/**
* Method to train a decision tree model over an RDD
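
A hedged usage sketch (not part of the patch) showing a numeric strategy passed through the existing mllib entry point; the input path and all parameter values are illustrative only.

import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}

object NumericStrategyUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("NumericStrategyUsageSketch").setMaster("local[*]"))
    // Hypothetical LIBSVM input path; replace with real data.
    val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")
    val model = RandomForest.trainClassifier(
      data,
      2,               // numClasses
      Map[Int, Int](), // categoricalFeaturesInfo
      10,              // numTrees
      "0.3",           // featureSubsetStrategy: with this patch, 30% of the features per node
      "gini",          // impurity
      5,               // maxDepth
      32)              // maxBins
    println(model.toDebugString)
    sc.stop()
  }
}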