Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala | 46
1 file changed, 26 insertions(+), 20 deletions(-)
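
The patch below replaces the long-disabled save/load test with the shared DefaultReadWriteTest machinery: testEstimatorAndModelReadWrite fits the estimator on the supplied DataFrame, round-trips both the estimator and the fitted model through save/load, and compares the results with the supplied checkModelData function. For orientation only, here is a minimal standalone sketch of that same round trip against the public spark.ml API. It is not part of the patch, it assumes a Spark 2.x+ SparkSession with ml.linalg vectors (newer than the sqlContext-based code in this suite), and the object name, toy data, and /tmp path are illustrative placeholders.

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
import org.apache.spark.sql.SparkSession

// Hypothetical example object; not part of the Spark test suite.
object RandomForestReadWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("rf-read-write-sketch")
      .getOrCreate()

    // Toy regression data: (label, features). Real data would be far larger.
    val training = spark.createDataFrame(Seq(
      (1.0, Vectors.dense(0.0, 1.1)),
      (0.0, Vectors.dense(2.0, 1.0)),
      (3.0, Vectors.dense(2.0, 0.1)),
      (2.5, Vectors.dense(1.0, 0.5))
    )).toDF("label", "features")

    val rf = new RandomForestRegressor().setNumTrees(2).setMaxDepth(3)
    val model = rf.fit(training)

    // Save the fitted model and load it back; the suite's checkModelData
    // assertions correspond roughly to the checks below.
    val path = "/tmp/rf-regression-model-sketch" // placeholder path
    model.write.overwrite().save(path)
    val loaded = RandomForestRegressionModel.load(path)
    assert(loaded.numFeatures == model.numFeatures)
    assert(loaded.trees.length == model.trees.length)

    spark.stop()
  }
}
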
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 6be0c8bca0..ca400e1914 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -30,7 +30,8 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[RandomForestRegressor]].
*/
-class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
+  with DefaultReadWriteTest {
import RandomForestRegressorSuite.compareAPIs
@@ -94,30 +95,35 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
    assert(importances.toArray.forall(_ >= 0.0))
  }
+ test("should support all NumericType labels and not support other types") {
+ val rf = new RandomForestRegressor().setMaxDepth(1)
+ MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor](
+ rf, isClassification = false, sqlContext) { (expected, actual) =>
+ TreeTests.checkEqual(expected, actual)
+ }
+ }
+
  /////////////////////////////////////////////////////////////////////////////
  // Tests of model save/load
  /////////////////////////////////////////////////////////////////////////////
-  // TODO: Reinstate test once save/load are implemented SPARK-6725
-  /*
-  test("model save/load") {
-    val tempDir = Utils.createTempDir()
-    val path = tempDir.toURI.toString
-
-    val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
-    val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees)
-    val newModel = RandomForestRegressionModel.fromOld(oldModel)
-
-    // Save model, load it back, and compare.
-    try {
-      newModel.save(sc, path)
-      val sameNewModel = RandomForestRegressionModel.load(sc, path)
-      TreeTests.checkEqual(newModel, sameNewModel)
-    } finally {
-      Utils.deleteRecursively(tempDir)
+ test("read/write") {
+ def checkModelData(
+ model: RandomForestRegressionModel,
+ model2: RandomForestRegressionModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
}
+
+ val rf = new RandomForestRegressor().setNumTrees(2)
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "variance")
+
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
}
-  */
}
private object RandomForestRegressorSuite extends SparkFunSuite {