diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala | 48 |
1 files changed, 27 insertions, 21 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index b896099e31..aaaa429103 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -34,7 +34,8 @@ import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestClassifier]]. */ -class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { +class RandomForestClassifierSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import RandomForestClassifierSuite.compareAPIs @@ -178,31 +179,36 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte assert(importances.toArray.forall(_ >= 0.0)) } + test("should support all NumericType labels and not support other types") { + val rf = new RandomForestClassifier().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier]( + rf, isClassification = true, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* - test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = - Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray - val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees) - val newModel = RandomForestClassificationModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = RandomForestClassificationModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + test("read/write") { + def checkModelData( + model: RandomForestClassificationModel, + model2: RandomForestClassificationModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) + assert(model.numClasses === model2.numClasses) } + + val rf = new RandomForestClassifier().setNumTrees(2) + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy") + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData) } - */ } private object RandomForestClassifierSuite extends SparkFunSuite { |