aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java19
1 files changed, 19 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index b6f793f6de..a8736669f7 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.Map;
import org.junit.After;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -80,6 +81,24 @@ public class JavaRandomForestRegressorSuite implements Serializable {
for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
+ String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
+ for (String strategy: realStrategies) {
+ rf.setFeatureSubsetStrategy(strategy);
+ }
+ String integerStrategies[] = {"1", "10", "100", "1000", "10000"};
+ for (String strategy: integerStrategies) {
+ rf.setFeatureSubsetStrategy(strategy);
+ }
+ String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
+ for (String strategy: invalidStrategies) {
+ try {
+ rf.setFeatureSubsetStrategy(strategy);
+ Assert.fail("Expected exception to be thrown for invalid strategies");
+ } catch (Exception e) {
+ Assert.assertTrue(e instanceof IllegalArgumentException);
+ }
+ }
+
RandomForestRegressionModel model = rf.fit(dataFrame);
model.transform(dataFrame);