author    | Joseph K. Bradley <joseph@databricks.com> | 2016-03-16 14:18:35 -0700
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-03-16 14:18:35 -0700
commit    | 6fc2b6541fd5ab73b289af5f7296fc602b5b4dce (patch)
tree      | ec8da69765b849a72e0faf5914f25a6dbd4d21f6 /mllib/src/test/scala/org
parent    | 3f06eb72ca0c3e5779a702c7c677229e0c480751 (diff)
[SPARK-11888][ML] Decision tree persistence in spark.ml
### What changes were proposed in this pull request?
Made these MLReadable and MLWritable: DecisionTreeClassifier, DecisionTreeClassificationModel, DecisionTreeRegressor, DecisionTreeRegressionModel
* The shared implementation is in treeModels.scala
* I use case classes to create a DataFrame to save, and the Dataset API to parse the loaded files (see the sketch below).
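A minimal sketch of that save/load pattern, assuming an illustrative `NodeData` case class and helper object (the actual schema and helpers live in treeModels.scala and may differ):

```scala
import org.apache.spark.sql.SQLContext

// Hypothetical flattened form of one tree node; the real schema in treeModels.scala may differ.
case class NodeData(id: Int, prediction: Double, leftChild: Int, rightChild: Int)

object TreePersistenceSketch {
  // Save: turn the case-class rows into a DataFrame and write it out (Parquet used here for illustration).
  def saveNodes(nodes: Seq[NodeData], path: String, sqlContext: SQLContext): Unit = {
    import sqlContext.implicits._
    nodes.toDF().write.parquet(path)
  }

  // Load: read the files back and parse the rows with the typed Dataset API.
  def loadNodes(path: String, sqlContext: SQLContext): Array[NodeData] = {
    import sqlContext.implicits._
    sqlContext.read.parquet(path).as[NodeData].collect()
  }
}
```

Using case classes keeps the on-disk layout schema-based, and the Dataset API gives typed parsing on load instead of manual Row handling.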
Other changes:
* Made CategoricalSplit.numCategories public (to use in persistence)
* Fixed a bug in DefaultReadWriteTest.testEstimatorAndModelReadWrite: it did not call the checkModelData function passed as an argument. This caused an error in LDASuite, which I also fixed (see the sketch below).
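For reference, a hedged sketch of what that fix amounts to (simplified signature; the real helper in DefaultReadWriteTest also exercises param settings, overwrite behavior, and temp directories):

```scala
object ReadWriteTestSketch {
  // Simplified, assumed shape of the test helper: fit a model, save/load it, then compare.
  def testEstimatorAndModelReadWrite[M](
      fitModel: () => M,
      saveAndLoad: M => M,
      checkModelData: (M, M) => Unit): Unit = {
    val model = fitModel()
    val model2 = saveAndLoad(model)
    // The bug: before this patch the caller-supplied comparison below was never invoked,
    // so model-specific checks silently did nothing.
    checkModelData(model, model2)
  }
}
```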
### How was this patch tested?
Persistence is tested via unit tests. For each algorithm, there are two non-trivial trees of depth 2: one built with continuous features and one with categorical features, so that both types of splits are tested.
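The round trip these tests exercise corresponds roughly to the following user-facing usage in a spark-shell session (the path and the `trainingData` DataFrame are illustrative):

```scala
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}

// trainingData: an assumed DataFrame with "label" and "features" columns.
val dt = new DecisionTreeClassifier().setMaxDepth(2)
val model = dt.fit(trainingData)

// Save the fitted model, load it back, and sanity-check the tree structure.
model.write.overwrite().save("/tmp/spark-dt-model")
val loaded = DecisionTreeClassificationModel.load("/tmp/spark-dt-model")
assert(loaded.numNodes == model.numNodes && loaded.depth == model.depth)
```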
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #11581 from jkbradley/dt-io.
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 50
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala | 2
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala | 2
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala | 35
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala | 2
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala | 2
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | 1
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala (renamed from mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala) | 37
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala | 2
9 files changed, 101 insertions, 32 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 6d68364499..2b07524815 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -18,10 +18,10 @@ package org.apache.spark.ml.classification
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
   DecisionTreeSuite => OldDecisionTreeSuite}
@@ -30,7 +30,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.Row
 
-class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class DecisionTreeClassifierSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   import DecisionTreeClassifierSuite.compareAPIs
 
@@ -338,25 +339,34 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
 
-  // TODO: Reinstate test once save/load are implemented SPARK-6725
-  /*
-  test("model save/load") {
-    val tempDir = Utils.createTempDir()
-    val path = tempDir.toURI.toString
-
-    val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification)
-    val newModel = DecisionTreeClassificationModel.fromOld(oldModel)
-
-    // Save model, load it back, and compare.
-    try {
-      newModel.save(sc, path)
-      val sameNewModel = DecisionTreeClassificationModel.load(sc, path)
-      TreeTests.checkEqual(newModel, sameNewModel)
-    } finally {
-      Utils.deleteRecursively(tempDir)
+  test("read/write") {
+    def checkModelData(
+        model: DecisionTreeClassificationModel,
+        model2: DecisionTreeClassificationModel): Unit = {
+      TreeTests.checkEqual(model, model2)
+      assert(model.numFeatures === model2.numFeatures)
+      assert(model.numClasses === model2.numClasses)
     }
+
+    val dt = new DecisionTreeClassifier()
+    val rdd = TreeTests.getTreeReadWriteData(sc)
+
+    val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy")
+
+    // Categorical splits with tree depth 2
+    val categoricalData: DataFrame =
+      TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
+    testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData)
+
+    // Continuous splits with tree depth 2
+    val continuousData: DataFrame =
+      TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+    testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData)
+
+    // Continuous splits with tree depth 0
+    testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
+      checkModelData)
   }
-  */
 }
 
 private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 039141aeb6..29efd675ab 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -18,10 +18,10 @@ package org.apache.spark.ml.classification
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.MLTestingUtils
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 4c7c56782c..b896099e31 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -18,9 +18,9 @@ package org.apache.spark.ml.classification
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.MLTestingUtils
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 56b335a33a..662e3fc679 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -18,8 +18,8 @@ package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
@@ -28,7 +28,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 
-class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class DecisionTreeRegressorSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   import DecisionTreeRegressorSuite.compareAPIs
 
@@ -120,7 +121,33 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
 
-  // TODO: test("model save/load") SPARK-6725
+  test("read/write") {
+    def checkModelData(
+        model: DecisionTreeRegressionModel,
+        model2: DecisionTreeRegressionModel): Unit = {
+      TreeTests.checkEqual(model, model2)
+      assert(model.numFeatures === model2.numFeatures)
+    }
+
+    val dt = new DecisionTreeRegressor()
+    val rdd = TreeTests.getTreeReadWriteData(sc)
+
+    // Categorical splits with tree depth 2
+    val categoricalData: DataFrame =
+      TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0)
+    testEstimatorAndModelReadWrite(dt, categoricalData,
+      TreeTests.allParamSettings, checkModelData)
+
+    // Continuous splits with tree depth 2
+    val continuousData: DataFrame =
+      TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+    testEstimatorAndModelReadWrite(dt, continuousData,
+      TreeTests.allParamSettings, checkModelData)
+
+    // Continuous splits with tree depth 0
+    testEstimatorAndModelReadWrite(dt, continuousData,
+      TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
+  }
 }
 
 private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 244db8637b..db68606397 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.MLTestingUtils
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index efb117f8f9..6be0c8bca0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.MLTestingUtils
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index d5c238e9ae..9d922291a6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.impl.TreeTests
 import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.tree.impurity.GiniCalculator
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index 5561f6f0ef..12808b0305 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -15,12 +15,11 @@
  * limitations under the License.
  */
 
-package org.apache.spark.ml.impl
+package org.apache.spark.ml.tree.impl
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.SparkContext
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkContext, SparkFunSuite}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
 import org.apache.spark.ml.tree._
@@ -154,4 +153,36 @@ private[ml] object TreeTests extends SparkFunSuite {
     new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
     new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
   ))
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   *
+   * This set of Params is for all Decision Tree-based models.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "checkpointInterval" -> 7,
+    "seed" -> 543L,
+    "maxDepth" -> 2,
+    "maxBins" -> 20,
+    "minInstancesPerNode" -> 2,
+    "minInfoGain" -> 1e-14,
+    "maxMemoryInMB" -> 257,
+    "cacheNodeIds" -> true
+  )
+
+  /** Data for tree read/write tests which produces a non-trivial tree. */
+  def getTreeReadWriteData(sc: SparkContext): RDD[LabeledPoint] = {
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0, 2.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 2.0)))
+    sc.parallelize(arr)
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 8e5365af84..16280473c6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -33,6 +33,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
    * Checks "overwrite" option and params.
    * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name
    * in order to avoid conflicts from multiple calls to this method.
+   *
    * @param instance ML instance to test saving/loading
    * @param testParams If true, then test values of Params. Otherwise, just test overwrite option.
    * @tparam T ML instance type
@@ -85,6 +86,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
    * - Compare model data
    *
    * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+   *
    * @param estimator Estimator to test
    * @param dataset Dataset to pass to [[Estimator.fit()]]
    * @param testParams Set of [[Param]] values to set in estimator