aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache
diff options
context:
space:
mode:
authorYong Tang <yong.tang.github@outlook.com>2016-04-12 16:53:26 +0200
committerNick Pentreath <nick.pentreath@gmail.com>2016-04-12 16:53:26 +0200
commitda60b34d2f6eba19633e4f1b46504ce92cd6c179 (patch)
tree0e86dad722e512c70e6936aa89dee567f956629c /mllib/src/main/scala/org/apache
parent124cbfb683a5e959e1b5181d4d0cc56956b50385 (diff)
downloadspark-da60b34d2f6eba19633e4f1b46504ce92cd6c179.tar.gz
spark-da60b34d2f6eba19633e4f1b46504ce92cd6c179.tar.bz2
spark-da60b34d2f6eba19633e4f1b46504ce92cd6c179.zip
[SPARK-3724][ML] RandomForest: More options for feature subset size.
## What changes were proposed in this pull request? This PR tries to support more options for feature subset size in the RandomForest implementation. Previously, RandomForest only supported "auto", "all", "sqrt", "log2", "onethird". This PR tries to support any given value to allow model search. In this PR, `featureSubsetStrategy` could be passed with: a) a real number in the range of `(0.0-1.0]` that represents the fraction of the number of features in each subset, b) an integer number (`>0`) that represents the number of features in each subset. ## How was this patch tested? Two tests `JavaRandomForestClassifierSuite` and `JavaRandomForestRegressorSuite` have been updated to check the additional options for params in this PR. An additional test has been added to `org.apache.spark.mllib.tree.RandomForestSuite` to cover the cases in this PR. Author: Yong Tang <yong.tang.github@outlook.com> Closes #11989 from yongtang/SPARK-3724.
Diffstat (limited to 'mllib/src/main/scala/org/apache')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala11
3 files changed, 21 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index df8eb5d1f9..c7cde1563f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -183,11 +183,16 @@ private[spark] object DecisionTreeMetadata extends Logging {
}
case _ => featureSubsetStrategy
}
+
+ val isIntRegex = "^([1-9]\\d*)$".r
+ val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r
val numFeaturesPerNode: Int = _featureSubsetStrategy match {
case "all" => numFeatures
case "sqrt" => math.sqrt(numFeatures).ceil.toInt
case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
case "onethird" => (numFeatures / 3.0).ceil.toInt
+ case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt
+ case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt
}
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 78e6d3bfac..0767dc17e5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -329,6 +329,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
* - "onethird": use 1/3 of the features
* - "sqrt": use sqrt(number of features)
* - "log2": use log2(number of features)
+ * - "n": when n is in the range (0, 1.0], use n * number of features. When n
+ * is in the range (1, number of features), use n features.
* (default = "auto")
*
* These various settings are based on the following references:
@@ -346,7 +348,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
"The number of features to consider for splits at each tree node." +
s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
(value: String) =>
- RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
+ RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)
+ || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex))
setDefault(featureSubsetStrategy -> "auto")
@@ -393,6 +396,9 @@ private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
+
+ // The regex to capture "(0.0-1.0]", and "n" for integer 0 < n <= (number of features)
+ final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
}
private[ml] trait RandomForestClassifierParams
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 1841fa4a95..26755849ad 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -55,10 +55,15 @@ import org.apache.spark.util.Utils
* @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
* Supported values: "auto", "all", "sqrt", "log2", "onethird".
+ * Supported numerical values: "(0.0-1.0]", "[1-n]".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
+ * If a real value "n" in the range (0, 1.0] is set,
+ * use n * number of features.
+ * If an integer value "n" in the range (1, num features) is set,
+ * use n features.
* @param seed Random seed for bootstrapping and choosing feature subsets.
*/
private class RandomForest (
@@ -70,9 +75,11 @@ private class RandomForest (
strategy.assertValid()
require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
- require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
+ require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
+ || featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex),
s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
- s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.")
+ s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," +
+ s" (0.0-1.0], [1-n].")
/**
* Method to train a decision tree model over an RDD