aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-04-25 12:27:19 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-25 12:27:19 -0700
commita7160c4e3aae22600d05e257d0b4d2428754b8ea (patch)
tree55ce372698ad266b258dceb469e8a6a9ae9709b0 /mllib/src/test/java/org/apache
parenta61d65fc8b97c01be0fa756b52afdc91c46a8561 (diff)
downloadspark-a7160c4e3aae22600d05e257d0b4d2428754b8ea.tar.gz
spark-a7160c4e3aae22600d05e257d0b4d2428754b8ea.tar.bz2
spark-a7160c4e3aae22600d05e257d0b4d2428754b8ea.zip
[SPARK-6113] [ML] Tree ensembles for Pipelines API
This is a continuation of [https://github.com/apache/spark/pull/5530] (which was for Decision Trees), but for ensembles: Random Forests and Gradient-Boosted Trees. Please refer to the JIRA [https://issues.apache.org/jira/browse/SPARK-6113], the design doc linked from the JIRA, and the previous PR linked above for design discussions. This PR follows the example set by the previous PR for Decision Trees. It includes a few cleanups to Decision Trees. Note: There is one issue which will be addressed in a separate PR: Ensembles' component Models have no parent or fittingParamMap. I plan to submit a separate PR which makes those values in Model be Options. It does not matter much which PR gets merged first. CC: mengxr manishamde codedeft chouqin Author: Joseph K. Bradley <joseph@databricks.com> Closes #5626 from jkbradley/dt-api-ensembles and squashes the following commits: 729167a [Joseph K. Bradley] small cleanups based on code review bbae2a2 [Joseph K. Bradley] Updated per all comments in code review 855aa9a [Joseph K. Bradley] scala style fix ea3d901 [Joseph K. Bradley] Added GBT to spark.ml, with tests and examples c0f30c1 [Joseph K. Bradley] Added random forests and test suites to spark.ml. Not tested yet. Need to add example as well d045ebd [Joseph K. Bradley] some more updates, but far from done ee1a10b [Joseph K. Bradley] Added files from old PR and did some initial updates.
Diffstat (limited to 'mllib/src/test/java/org/apache')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java10
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java100
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java103
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java26
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java99
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java102
6 files changed, 420 insertions, 20 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index 43b8787f9d..60f25e5cce 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.ml.classification;
-import java.io.File;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
@@ -32,7 +31,6 @@ import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
-import org.apache.spark.util.Utils;
public class JavaDecisionTreeClassifierSuite implements Serializable {
@@ -57,7 +55,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize(
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@@ -71,8 +69,8 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
.setCacheNodeIds(false)
.setCheckpointInterval(10)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) {
- dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]);
+ for (String impurity: DecisionTreeClassifier.supportedImpurities()) {
+ dt.setImpurity(impurity);
}
DecisionTreeClassificationModel model = dt.fit(dataFrame);
@@ -82,7 +80,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
model.toDebugString();
/*
- // TODO: Add test once save/load are implemented.
+ // TODO: Add test once save/load are implemented. SPARK-6725
File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
String path = tempDir.toURI().toString();
try {
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
new file mode 100644
index 0000000000..3c69467fa1
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaGBTClassifierSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaGBTClassifierSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ GBTClassifier rf = new GBTClassifier()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setSubsamplingRate(1.0)
+ .setSeed(1234)
+ .setMaxIter(3)
+ .setStepSize(0.1)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String lossType: GBTClassifier.supportedLossTypes()) {
+ rf.setLossType(lossType);
+ }
+ GBTClassificationModel model = rf.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.totalNumNodes();
+ model.toDebugString();
+ model.trees();
+ model.treeWeights();
+
+ /*
+ // TODO: Add test once save/load are implemented. SPARK-6725
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model3.save(sc.sc(), path);
+ GBTClassificationModel sameModel = GBTClassificationModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model3, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
new file mode 100644
index 0000000000..32d0b3856b
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaRandomForestClassifierSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ RandomForestClassifier rf = new RandomForestClassifier()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setSubsamplingRate(1.0)
+ .setSeed(1234)
+ .setNumTrees(3)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String impurity: RandomForestClassifier.supportedImpurities()) {
+ rf.setImpurity(impurity);
+ }
+ for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) {
+ rf.setFeatureSubsetStrategy(featureSubsetStrategy);
+ }
+ RandomForestClassificationModel model = rf.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.totalNumNodes();
+ model.toDebugString();
+ model.trees();
+ model.treeWeights();
+
+ /*
+ // TODO: Add test once save/load are implemented. SPARK-6725
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model3.save(sc.sc(), path);
+ RandomForestClassificationModel sameModel =
+ RandomForestClassificationModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model3, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index a3a339004f..71b041818d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.ml.regression;
-import java.io.File;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
@@ -32,7 +31,6 @@ import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
-import org.apache.spark.util.Utils;
public class JavaDecisionTreeRegressorSuite implements Serializable {
@@ -57,22 +55,22 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize(
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
// This tests setters. Training with various options is tested in Scala.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
- .setMaxDepth(2)
- .setMaxBins(10)
- .setMinInstancesPerNode(5)
- .setMinInfoGain(0.0)
- .setMaxMemoryInMB(256)
- .setCacheNodeIds(false)
- .setCheckpointInterval(10)
- .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) {
- dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]);
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String impurity: DecisionTreeRegressor.supportedImpurities()) {
+ dt.setImpurity(impurity);
}
DecisionTreeRegressionModel model = dt.fit(dataFrame);
@@ -82,7 +80,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
model.toDebugString();
/*
- // TODO: Add test once save/load are implemented.
+ // TODO: Add test once save/load are implemented. SPARK-6725
File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
String path = tempDir.toURI().toString();
try {
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
new file mode 100644
index 0000000000..fc8c13db07
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaGBTRegressorSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaGBTRegressorSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+
+ GBTRegressor rf = new GBTRegressor()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setSubsamplingRate(1.0)
+ .setSeed(1234)
+ .setMaxIter(3)
+ .setStepSize(0.1)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String lossType: GBTRegressor.supportedLossTypes()) {
+ rf.setLossType(lossType);
+ }
+ GBTRegressionModel model = rf.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.totalNumNodes();
+ model.toDebugString();
+ model.trees();
+ model.treeWeights();
+
+ /*
+ // TODO: Add test once save/load are implemented. SPARK-6725
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model2.save(sc.sc(), path);
+ GBTRegressionModel sameModel = GBTRegressionModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model2, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
new file mode 100644
index 0000000000..e306ebadfe
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaRandomForestRegressorSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+
+ // This tests setters. Training with various options is tested in Scala.
+ RandomForestRegressor rf = new RandomForestRegressor()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setSubsamplingRate(1.0)
+ .setSeed(1234)
+ .setNumTrees(3)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String impurity: RandomForestRegressor.supportedImpurities()) {
+ rf.setImpurity(impurity);
+ }
+ for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
+ rf.setFeatureSubsetStrategy(featureSubsetStrategy);
+ }
+ RandomForestRegressionModel model = rf.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.totalNumNodes();
+ model.toDebugString();
+ model.trees();
+ model.treeWeights();
+
+ /*
+ // TODO: Add test once save/load are implemented. SPARK-6725
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model2.save(sc.sc(), path);
+ RandomForestRegressionModel sameModel = RandomForestRegressionModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model2, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}