aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala48
1 files changed, 27 insertions, 21 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index b896099e31..aaaa429103 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@@ -34,7 +34,8 @@ import org.apache.spark.sql.{DataFrame, Row}
/**
* Test suite for [[RandomForestClassifier]].
*/
-class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RandomForestClassifierSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
import RandomForestClassifierSuite.compareAPIs
@@ -178,31 +179,36 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
assert(importances.toArray.forall(_ >= 0.0))
}
+ test("should support all NumericType labels and not support other types") {
+ val rf = new RandomForestClassifier().setMaxDepth(1)
+ MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier](
+ rf, isClassification = true, sqlContext) { (expected, actual) =>
+ TreeTests.checkEqual(expected, actual)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
- test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees =
- Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray
- val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees)
- val newModel = RandomForestClassificationModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = RandomForestClassificationModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ test("read/write") {
+ def checkModelData(
+ model: RandomForestClassificationModel,
+ model2: RandomForestClassificationModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
+ assert(model.numClasses === model2.numClasses)
}
+
+ val rf = new RandomForestClassifier().setNumTrees(2)
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy")
+
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
}
- */
}
private object RandomForestClassifierSuite extends SparkFunSuite {