aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala
diff options
context:
space:
mode:
authorYong Tang <yong.tang.github@outlook.com>2016-04-12 16:53:26 +0200
committerNick Pentreath <nick.pentreath@gmail.com>2016-04-12 16:53:26 +0200
commitda60b34d2f6eba19633e4f1b46504ce92cd6c179 (patch)
tree0e86dad722e512c70e6936aa89dee567f956629c /mllib/src/test/scala
parent124cbfb683a5e959e1b5181d4d0cc56956b50385 (diff)
downloadspark-da60b34d2f6eba19633e4f1b46504ce92cd6c179.tar.gz
spark-da60b34d2f6eba19633e4f1b46504ce92cd6c179.tar.bz2
spark-da60b34d2f6eba19633e4f1b46504ce92cd6c179.zip
[SPARK-3724][ML] RandomForest: More options for feature subset size.
## What changes were proposed in this pull request? This PR tries to support more options for feature subset size in RandomForest implementation. Previously, RandomForest only support "auto", "all", "sort", "log2", "onethird". This PR tries to support any given value to allow model search. In this PR, `featureSubsetStrategy` could be passed with: a) a real number in the range of `(0.0-1.0]` that represents the fraction of the number of features in each subset, b) an integer number (`>0`) that represents the number of features in each subset. ## How was this patch tested? Two tests `JavaRandomForestClassifierSuite` and `JavaRandomForestRegressorSuite` have been updated to check the additional options for params in this PR. An additional test has been added to `org.apache.spark.mllib.tree.RandomForestSuite` to cover the cases in this PR. Author: Yong Tang <yong.tang.github@outlook.com> Closes #11989 from yongtang/SPARK-3724.
Diffstat (limited to 'mllib/src/test/scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala36
1 files changed, 36 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index cd402b1e1f..6db9ce150d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -426,12 +426,48 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
+ val realStrategies = Array(".1", ".10", "0.10", "0.1", "0.9", "1.0")
+ for (strategy <- realStrategies) {
+ val expected = (strategy.toDouble * numFeatures).ceil.toInt
+ checkFeatureSubsetStrategy(numTrees = 1, strategy, expected)
+ }
+
+ val integerStrategies = Array("1", "10", "100", "1000", "10000")
+ for (strategy <- integerStrategies) {
+ val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures
+ checkFeatureSubsetStrategy(numTrees = 1, strategy, expected)
+ }
+
+ val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0")
+ for (invalidStrategy <- invalidStrategies) {
+ intercept[MatchError]{
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy)
+ }
+ }
+
checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "log2",
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
+
+ for (strategy <- realStrategies) {
+ val expected = (strategy.toDouble * numFeatures).ceil.toInt
+ checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
+ }
+
+ for (strategy <- integerStrategies) {
+ val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures
+ checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
+ }
+ for (invalidStrategy <- invalidStrategies) {
+ intercept[MatchError]{
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy)
+ }
+ }
}
test("Binary classification with continuous features: subsampling features") {