author    Joseph K. Bradley <joseph@databricks.com>  2016-03-16 14:18:35 -0700
committer Joseph K. Bradley <joseph@databricks.com>  2016-03-16 14:18:35 -0700
commit    6fc2b6541fd5ab73b289af5f7296fc602b5b4dce (patch)
tree      ec8da69765b849a72e0faf5914f25a6dbd4d21f6 /mllib/src/test/scala/org
parent    3f06eb72ca0c3e5779a702c7c677229e0c480751 (diff)
[SPARK-11888][ML] Decision tree persistence in spark.ml
### What changes were proposed in this pull request?

Made these MLReadable and MLWritable: DecisionTreeClassifier, DecisionTreeClassificationModel, DecisionTreeRegressor, DecisionTreeRegressionModel
* The shared implementation is in treeModels.scala
* I use case classes to create a DataFrame to save, and I use the Dataset API to parse loaded files.

Other changes:
* Made CategoricalSplit.numCategories public (to use in persistence)
* Fixed a bug in DefaultReadWriteTest.testEstimatorAndModelReadWrite, where it did not call the checkModelData function passed as an argument. This caused an error in LDASuite, which I fixed.

### How was this patch tested?

Persistence is tested via unit tests. For each algorithm, there are 2 non-trivial trees (depth 2). One is built with continuous features, and one with categorical; this ensures that both types of splits are tested.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #11581 from jkbradley/dt-io.
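The case-class persistence pattern described above can be sketched roughly as follows. This is an illustrative outline only; NodeData and its fields are placeholder assumptions, not the actual schema in treeModels.scala:

import org.apache.spark.sql.SQLContext

// Illustrative sketch of the save/load pattern: encode tree nodes as case
// class rows, save them as a DataFrame, and parse them back with the
// Dataset API. NodeData and its fields are placeholders, not the exact
// schema used in treeModels.scala.
case class NodeData(
    id: Int,            // node identifier within the tree
    prediction: Double, // prediction made at this node
    impurity: Double,   // impurity of the data routed to this node
    leftChild: Int,     // child node ids; -1 marks a leaf
    rightChild: Int)

def saveNodes(sqlContext: SQLContext, nodes: Seq[NodeData], path: String): Unit = {
  import sqlContext.implicits._
  // Save: case classes -> DataFrame -> Parquet
  nodes.toDF().write.parquet(path)
}

def loadNodes(sqlContext: SQLContext, path: String): Array[NodeData] = {
  import sqlContext.implicits._
  // Load: read the files back and parse rows into typed objects
  sqlContext.read.parquet(path).as[NodeData].collect()
}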
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 50
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala | 35
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | 1
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala (renamed from mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala) | 37
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala | 2
9 files changed, 101 insertions, 32 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 6d68364499..2b07524815 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -18,10 +18,10 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
@@ -30,7 +30,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
-class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class DecisionTreeClassifierSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
import DecisionTreeClassifierSuite.compareAPIs
@@ -338,25 +339,34 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
- test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification)
- val newModel = DecisionTreeClassificationModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = DecisionTreeClassificationModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ test("read/write") {
+ def checkModelData(
+ model: DecisionTreeClassificationModel,
+ model2: DecisionTreeClassificationModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
+ assert(model.numClasses === model2.numClasses)
}
+
+ val dt = new DecisionTreeClassifier()
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy")
+
+ // Categorical splits with tree depth 2
+ val categoricalData: DataFrame =
+ TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
+ testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData)
+
+ // Continuous splits with tree depth 2
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData)
+
+ // Continuous splits with tree depth 0
+ testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
+ checkModelData)
}
- */
}
private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
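The net effect for users: DecisionTreeClassificationModel now round-trips through the standard ML writer/reader. A minimal usage sketch, assuming a SQLContext named sqlContext is in scope; the training data and path are illustrative:

import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Fit a small model on toy data.
val training = sqlContext.createDataFrame(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 1.0))))
val model = new DecisionTreeClassifier().setMaxDepth(2).fit(training)

// Save the fitted model, then load it back and compare basic metadata.
model.write.overwrite().save("/tmp/dt-classifier")
val sameModel = DecisionTreeClassificationModel.load("/tmp/dt-classifier")
assert(sameModel.numClasses == model.numClasses)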
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 039141aeb6..29efd675ab 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -18,10 +18,10 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 4c7c56782c..b896099e31 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -18,9 +18,9 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 56b335a33a..662e3fc679 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -18,8 +18,8 @@
package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
@@ -28,7 +28,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
-class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class DecisionTreeRegressorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
import DecisionTreeRegressorSuite.compareAPIs
@@ -120,7 +121,33 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: test("model save/load") SPARK-6725
+ test("read/write") {
+ def checkModelData(
+ model: DecisionTreeRegressionModel,
+ model2: DecisionTreeRegressionModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
+ }
+
+ val dt = new DecisionTreeRegressor()
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ // Categorical splits with tree depth 2
+ val categoricalData: DataFrame =
+ TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0)
+ testEstimatorAndModelReadWrite(dt, categoricalData,
+ TreeTests.allParamSettings, checkModelData)
+
+ // Continuous splits with tree depth 2
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ testEstimatorAndModelReadWrite(dt, continuousData,
+ TreeTests.allParamSettings, checkModelData)
+
+ // Continuous splits with tree depth 0
+ testEstimatorAndModelReadWrite(dt, continuousData,
+ TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
+ }
}
private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
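One difference from the classifier suite is worth calling out: here TreeTests.setMetadata is passed numClasses = 0, which marks the label column as continuous so the estimator is exercised as a regressor. A side-by-side sketch, assuming rdd comes from getTreeReadWriteData:

// numClasses = 2: binary categorical label (classification suite above)
val classifierData = TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
// numClasses = 0: continuous label (regression suite above)
val regressorData = TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0)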
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 244db8637b..db68606397 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index efb117f8f9..6be0c8bca0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index d5c238e9ae..9d922291a6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.tree.impurity.GiniCalculator
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index 5561f6f0ef..12808b0305 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -15,12 +15,11 @@
* limitations under the License.
*/
-package org.apache.spark.ml.impl
+package org.apache.spark.ml.tree.impl
import scala.collection.JavaConverters._
-import org.apache.spark.SparkContext
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.tree._
@@ -154,4 +153,36 @@ private[ml] object TreeTests extends SparkFunSuite {
new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
))
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ *
+ * This set of Params is for all Decision Tree-based models.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "checkpointInterval" -> 7,
+ "seed" -> 543L,
+ "maxDepth" -> 2,
+ "maxBins" -> 20,
+ "minInstancesPerNode" -> 2,
+ "minInfoGain" -> 1e-14,
+ "maxMemoryInMB" -> 257,
+ "cacheNodeIds" -> true
+ )
+
+ /** Data for tree read/write tests which produces a non-trivial tree. */
+ def getTreeReadWriteData(sc: SparkContext): RDD[LabeledPoint] = {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 2.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 2.0)))
+ sc.parallelize(arr)
+ }
}
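A brief sketch of how the suites above consume these two helpers, assuming a SparkContext sc as provided by MLlibTestSparkContext. The labels in getTreeReadWriteData cannot be separated by any single split on either feature, so with maxDepth = 2 (from allParamSettings) the fitted tree is genuinely non-trivial:

import org.apache.spark.ml.classification.DecisionTreeClassifier

val rdd = TreeTests.getTreeReadWriteData(sc)
// Feature 0 has 2 categories, feature 1 has 3; the label is binary.
val data = TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
val dt = new DecisionTreeClassifier()
  .setMaxDepth(2)  // mirrors allParamSettings("maxDepth")
  .setSeed(543L)   // mirrors allParamSettings("seed")
// The result is the kind of depth-2 model the read/write tests round-trip.
val model = dt.fit(data)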
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 8e5365af84..16280473c6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -33,6 +33,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
* Checks "overwrite" option and params.
* This saves to and loads from [[tempDir]], but creates a subdirectory with a random name
* in order to avoid conflicts from multiple calls to this method.
+ *
* @param instance ML instance to test saving/loading
* @param testParams If true, then test values of Params. Otherwise, just test overwrite option.
* @tparam T ML instance type
@@ -85,6 +86,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
* - Compare model data
*
* This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+ *
* @param estimator Estimator to test
* @param dataset Dataset to pass to [[Estimator.fit()]]
* @param testParams Set of [[Param]] values to set in estimator