From da60b34d2f6eba19633e4f1b46504ce92cd6c179 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 12 Apr 2016 16:53:26 +0200 Subject: [SPARK-3724][ML] RandomForest: More options for feature subset size. ## What changes were proposed in this pull request? This PR adds support for more feature subset size options in the RandomForest implementation. Previously, RandomForest only supported "auto", "all", "sqrt", "log2", and "onethird". This PR allows any given value to be supplied, to enable model search. With this PR, `featureSubsetStrategy` can be set to: a) a real number in the range `(0.0, 1.0]` that represents the fraction of the number of features in each subset, or b) an integer (`> 0`) that represents the number of features in each subset. ## How was this patch tested? Two tests, `JavaRandomForestClassifierSuite` and `JavaRandomForestRegressorSuite`, have been updated to check the additional options for params in this PR. An additional test has been added to `org.apache.spark.mllib.tree.RandomForestSuite` to cover the cases in this PR. Author: Yong Tang Closes #11989 from yongtang/SPARK-3724. 
--- .../JavaRandomForestClassifierSuite.java | 19 +++++++++++++++++++ .../ml/regression/JavaRandomForestRegressorSuite.java | 19 +++++++++++++++++++ 2 files changed, 38 insertions(+) (limited to 'mllib/src/test/java/org') diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 75061464e5..5aec52ac72 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Map; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -80,6 +81,24 @@ public class JavaRandomForestClassifierSuite implements Serializable { for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } + String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; + for (String strategy: realStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String integerStrategies[] = {"1", "10", "100", "1000", "10000"}; + for (String strategy: integerStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; + for (String strategy: invalidStrategies) { + try { + rf.setFeatureSubsetStrategy(strategy); + Assert.fail("Expected exception to be thrown for invalid strategies"); + } catch (Exception e) { + Assert.assertTrue(e instanceof IllegalArgumentException); + } + } + RandomForestClassificationModel model = rf.fit(dataFrame); model.transform(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java 
index b6f793f6de..a8736669f7 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Map; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -80,6 +81,24 @@ public class JavaRandomForestRegressorSuite implements Serializable { for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } + String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; + for (String strategy: realStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String integerStrategies[] = {"1", "10", "100", "1000", "10000"}; + for (String strategy: integerStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; + for (String strategy: invalidStrategies) { + try { + rf.setFeatureSubsetStrategy(strategy); + Assert.fail("Expected exception to be thrown for invalid strategies"); + } catch (Exception e) { + Assert.assertTrue(e instanceof IllegalArgumentException); + } + } + RandomForestRegressionModel model = rf.fit(dataFrame); model.transform(dataFrame); -- cgit v1.2.3