aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala11
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java19
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java19
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala36
6 files changed, 95 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index df8eb5d1f9..c7cde1563f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -183,11 +183,16 @@ private[spark] object DecisionTreeMetadata extends Logging {
}
case _ => featureSubsetStrategy
}
+
+ val isIntRegex = "^([1-9]\\d*)$".r
+ val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r
val numFeaturesPerNode: Int = _featureSubsetStrategy match {
case "all" => numFeatures
case "sqrt" => math.sqrt(numFeatures).ceil.toInt
case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
case "onethird" => (numFeatures / 3.0).ceil.toInt
+ case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt
+ case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt
}
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 78e6d3bfac..0767dc17e5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -329,6 +329,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
* - "onethird": use 1/3 of the features
* - "sqrt": use sqrt(number of features)
* - "log2": use log2(number of features)
+ * - "n": when n is a fraction in the range (0, 1.0], use ceil(n * number of features)
+ *        features. When n is an integer >= 1, use min(n, number of features) features.
* (default = "auto")
*
* These various settings are based on the following references:
@@ -346,7 +348,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
"The number of features to consider for splits at each tree node." +
s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
(value: String) =>
- RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
+ RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)
+ || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex))
setDefault(featureSubsetStrategy -> "auto")
@@ -393,6 +396,9 @@ private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
+
+ // Regex matching a fraction in (0, 1.0] or a positive integer n; integers above the number of features are capped later.
+ final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
}
private[ml] trait RandomForestClassifierParams
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 1841fa4a95..26755849ad 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -55,10 +55,15 @@ import org.apache.spark.util.Utils
* @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
* Supported values: "auto", "all", "sqrt", "log2", "onethird".
+ * Supported numerical values: "(0.0-1.0]", "[1-n]".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
+ * If a fraction "n" in the range (0, 1.0] is set,
+ * use ceil(n * number of features) features.
+ * If an integer value "n" >= 1 is set,
+ * use min(n, number of features) features.
* @param seed Random seed for bootstrapping and choosing feature subsets.
*/
private class RandomForest (
@@ -70,9 +75,11 @@ private class RandomForest (
strategy.assertValid()
require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
- require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
+ require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
+ || featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex),
s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
- s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.")
+ s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," +
+ s" (0.0-1.0], [1-n].")
/**
* Method to train a decision tree model over an RDD
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 75061464e5..5aec52ac72 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.Map;
import org.junit.After;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -80,6 +81,24 @@ public class JavaRandomForestClassifierSuite implements Serializable {
for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
+ String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
+ for (String strategy: realStrategies) {
+ rf.setFeatureSubsetStrategy(strategy);
+ }
+ String integerStrategies[] = {"1", "10", "100", "1000", "10000"};
+ for (String strategy: integerStrategies) {
+ rf.setFeatureSubsetStrategy(strategy);
+ }
+ String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
+ for (String strategy: invalidStrategies) {
+ try {
+ rf.setFeatureSubsetStrategy(strategy);
+ Assert.fail("Expected exception to be thrown for invalid strategies");
+ } catch (Exception e) {
+ Assert.assertTrue(e instanceof IllegalArgumentException);
+ }
+ }
+
RandomForestClassificationModel model = rf.fit(dataFrame);
model.transform(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index b6f793f6de..a8736669f7 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.Map;
import org.junit.After;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -80,6 +81,24 @@ public class JavaRandomForestRegressorSuite implements Serializable {
for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
+ String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
+ for (String strategy: realStrategies) {
+ rf.setFeatureSubsetStrategy(strategy);
+ }
+ String integerStrategies[] = {"1", "10", "100", "1000", "10000"};
+ for (String strategy: integerStrategies) {
+ rf.setFeatureSubsetStrategy(strategy);
+ }
+ String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
+ for (String strategy: invalidStrategies) {
+ try {
+ rf.setFeatureSubsetStrategy(strategy);
+ Assert.fail("Expected exception to be thrown for invalid strategies");
+ } catch (Exception e) {
+ Assert.assertTrue(e instanceof IllegalArgumentException);
+ }
+ }
+
RandomForestRegressionModel model = rf.fit(dataFrame);
model.transform(dataFrame);
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index cd402b1e1f..6db9ce150d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -426,12 +426,48 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
+ val realStrategies = Array(".1", ".10", "0.10", "0.1", "0.9", "1.0")
+ for (strategy <- realStrategies) {
+ val expected = (strategy.toDouble * numFeatures).ceil.toInt
+ checkFeatureSubsetStrategy(numTrees = 1, strategy, expected)
+ }
+
+ val integerStrategies = Array("1", "10", "100", "1000", "10000")
+ for (strategy <- integerStrategies) {
+ val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures
+ checkFeatureSubsetStrategy(numTrees = 1, strategy, expected)
+ }
+
+ val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0")
+ for (invalidStrategy <- invalidStrategies) {
+ intercept[MatchError]{
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy)
+ }
+ }
+
checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "log2",
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
+
+ for (strategy <- realStrategies) {
+ val expected = (strategy.toDouble * numFeatures).ceil.toInt
+ checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
+ }
+
+ for (strategy <- integerStrategies) {
+ val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures
+ checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
+ }
+ for (invalidStrategy <- invalidStrategies) {
+ intercept[MatchError]{
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy)
+ }
+ }
}
test("Binary classification with continuous features: subsampling features") {