author     Joseph K. Bradley <joseph@databricks.com>  2016-04-04 10:24:02 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2016-04-04 10:24:02 -0700
commit     89f3befab6c150f87de2fb91b50ea8b414c69095 (patch)
tree       5b6e77a97a6ca8247fec9f750640d80353c7ef1d /mllib/src/test
parent     745425332f41e2ae94649f9d1ad675243f36f743 (diff)
[SPARK-13784][ML] Persistence for RandomForestClassifier, RandomForestRegressor
## What changes were proposed in this pull request?

**Main change**: Added save/load for RandomForestClassifier, RandomForestRegressor (implementation details below).

Modified numTrees method (*deprecation*)
* Goal: Use default implementations of unit tests which assume Estimators and Models share the same set of Params.
* What this PR does: Moves method numTrees outside of trait TreeEnsembleModel. Adds it to GBT and RF Models. Deprecates it in RF Models in favor of new method getNumTrees. In Spark 2.1, we can have RF Models include Param numTrees.

Minor items
* Fixes bugs in GBTClassificationModel, GBTRegressionModel fromOld methods where they assign the wrong old UID.

**Implementation details**
* Split DecisionTreeModelReadWrite.loadTreeNodes into 2 methods in order to reuse some code for ensembles.
* Added EnsembleModelReadWrite object with save/load implementations usable for RFs and GBTs.
  * These store all trees' nodes in a single DataFrame, and all trees' metadata in a second DataFrame.
* Split trait RandomForestParams into parts in order to add more Estimator Params to RF models.
* Split DefaultParamsWriter.saveMetadata into two methods to allow ensembles to store sub-models' metadata in a single DataFrame. Same for DefaultParamsReader.loadMetadata.

## How was this patch tested?

Adds standard unit tests for RF save/load.

Author: Joseph K. Bradley <joseph@databricks.com>
Author: GayathriMurali <gayathri.m.softie@gmail.com>

Closes #12118 from jkbradley/GayathriMurali-SPARK-13784.
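For a sense of what this change enables at the API level, here is a minimal sketch of the new save/load round trip, written against the current spark.ml API. The toy dataset, the `/tmp` path, and the parameter values are placeholders, and the sketch uses the post-2.0 `ml.linalg` vectors, whereas the tests in this diff still import `mllib.linalg`:

```scala
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("rf-persistence-sketch").getOrCreate()
import spark.implicits._

// Tiny toy dataset with the conventional "label"/"features" columns.
val training = Seq(
  (0.0, Vectors.dense(0.0, 1.0)),
  (1.0, Vectors.dense(1.0, 0.0)),
  (0.0, Vectors.dense(0.1, 0.9)),
  (1.0, Vectors.dense(0.9, 0.2))
).toDF("label", "features")

val rf = new RandomForestClassifier().setNumTrees(3).setMaxDepth(2)
val model = rf.fit(training)

// New with this change: RF models can be written out and read back.
val path = "/tmp/rf-model"  // placeholder path
model.write.overwrite().save(path)
val restored = RandomForestClassificationModel.load(path)

// numTrees on the model is deprecated in favor of getNumTrees.
assert(restored.getNumTrees == model.getNumTrees)
```

The final `assert` uses `getNumTrees`, the replacement this PR introduces for the deprecated `numTrees` method on RF models.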
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala  40
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala  38
2 files changed, 37 insertions(+), 41 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 052bc83c38..aaaa429103 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@@ -34,7 +34,8 @@ import org.apache.spark.sql.{DataFrame, Row}
 /**
  * Test suite for [[RandomForestClassifier]].
  */
-class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RandomForestClassifierSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   import RandomForestClassifierSuite.compareAPIs
 
@@ -190,27 +191,24 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
 
-  // TODO: Reinstate test once save/load are implemented  SPARK-6725
-  /*
-  test("model save/load") {
-    val tempDir = Utils.createTempDir()
-    val path = tempDir.toURI.toString
-
-    val trees =
-      Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray
-    val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees)
-    val newModel = RandomForestClassificationModel.fromOld(oldModel)
-
-    // Save model, load it back, and compare.
-    try {
-      newModel.save(sc, path)
-      val sameNewModel = RandomForestClassificationModel.load(sc, path)
-      TreeTests.checkEqual(newModel, sameNewModel)
-    } finally {
-      Utils.deleteRecursively(tempDir)
+  test("read/write") {
+    def checkModelData(
+        model: RandomForestClassificationModel,
+        model2: RandomForestClassificationModel): Unit = {
+      TreeTests.checkEqual(model, model2)
+      assert(model.numFeatures === model2.numFeatures)
+      assert(model.numClasses === model2.numClasses)
     }
+
+    val rf = new RandomForestClassifier().setNumTrees(2)
+    val rdd = TreeTests.getTreeReadWriteData(sc)
+
+    val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy")
+
+    val continuousData: DataFrame =
+      TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
   }
-  */
 }
 
 private object RandomForestClassifierSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 2ab4f1b146..ca400e1914 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -30,7 +30,8 @@ import org.apache.spark.sql.DataFrame
 /**
  * Test suite for [[RandomForestRegressor]].
  */
-class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
+  with DefaultReadWriteTest{
 
   import RandomForestRegressorSuite.compareAPIs
 
@@ -106,26 +107,23 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
 
-  // TODO: Reinstate test once save/load are implemented  SPARK-6725
-  /*
-  test("model save/load") {
-    val tempDir = Utils.createTempDir()
-    val path = tempDir.toURI.toString
-
-    val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
-    val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees)
-    val newModel = RandomForestRegressionModel.fromOld(oldModel)
-
-    // Save model, load it back, and compare.
-    try {
-      newModel.save(sc, path)
-      val sameNewModel = RandomForestRegressionModel.load(sc, path)
-      TreeTests.checkEqual(newModel, sameNewModel)
-    } finally {
-      Utils.deleteRecursively(tempDir)
+  test("read/write") {
+    def checkModelData(
+        model: RandomForestRegressionModel,
+        model2: RandomForestRegressionModel): Unit = {
+      TreeTests.checkEqual(model, model2)
+      assert(model.numFeatures === model2.numFeatures)
     }
+
+    val rf = new RandomForestRegressor().setNumTrees(2)
+    val rdd = TreeTests.getTreeReadWriteData(sc)
+
+    val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "variance")
+
+    val continuousData: DataFrame =
+      TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
   }
-  */
 }
 
 private object RandomForestRegressorSuite extends SparkFunSuite {
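A closing note on the storage layout the commit message describes: EnsembleModelReadWrite tags every node with its tree's index so that a single DataFrame holds the whole forest. The sketch below illustrates the idea only; the case classes are simplified stand-ins, not the exact internal schema of `org.apache.spark.ml.tree`, and the path is a placeholder:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("ensemble-layout-sketch").getOrCreate()
import spark.implicits._

// Illustrative stand-ins: the real NodeData carries more fields
// (impurity, impurityStats, gain, split, ...).
case class NodeData(id: Int, prediction: Double, leftChild: Int, rightChild: Int)
case class EnsembleNodeData(treeID: Int, nodeData: NodeData)

// Every tree's nodes are tagged with the tree's index and unioned into one
// DataFrame, so the whole ensemble round-trips through a single Parquet write.
val nodes = Seq(
  EnsembleNodeData(0, NodeData(0, 0.5, 1, 2)),
  EnsembleNodeData(0, NodeData(1, 0.0, -1, -1)),
  EnsembleNodeData(0, NodeData(2, 1.0, -1, -1)),
  EnsembleNodeData(1, NodeData(0, 1.0, -1, -1))
).toDF()

nodes.write.mode("overwrite").parquet("/tmp/ensemble-sketch")  // placeholder path

// On load, rows are grouped by treeID and each group is rebuilt into a tree.
val reloaded = spark.read.parquet("/tmp/ensemble-sketch")
reloaded.groupBy("treeID").count().show()
```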