diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala | 46 |
1 files changed, 26 insertions, 20 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 6be0c8bca0..ca400e1914 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -30,7 +30,8 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[RandomForestRegressor]]. */ -class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { +class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest{ import RandomForestRegressorSuite.compareAPIs @@ -94,30 +95,35 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex assert(importances.toArray.forall(_ >= 0.0)) } + test("should support all NumericType labels and not support other types") { + val rf = new RandomForestRegressor().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor]( + rf, isClassification = false, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* - test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees) - val newModel = RandomForestRegressionModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = RandomForestRegressionModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + test("read/write") { + def checkModelData( + model: RandomForestRegressionModel, + model2: RandomForestRegressionModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val rf = new RandomForestRegressor().setNumTrees(2) + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "variance") + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData) } - */ } private object RandomForestRegressorSuite extends SparkFunSuite { |