aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-02-09 22:09:07 -0800
committerXiangrui Meng <meng@databricks.com>2015-02-09 22:09:07 -0800
commitef2f55b97f58fa06acb30e9e0172fb66fba383bc (patch)
tree3e3cfab15830e2f85db891f2adb12a3c9d7a09c1 /mllib/src/test
parentbd0b5ea708aa5b84adb67c039ec52408289718bb (diff)
downloadspark-ef2f55b97f58fa06acb30e9e0172fb66fba383bc.tar.gz
spark-ef2f55b97f58fa06acb30e9e0172fb66fba383bc.tar.bz2
spark-ef2f55b97f58fa06acb30e9e0172fb66fba383bc.zip
[SPARK-5597][MLLIB] save/load for decision trees and ensembles
This is based on #4444 from jkbradley with the following changes: 1. Node schema updated to ~~~ treeId: int nodeId: Int predict/ |- predict: Double |- prob: Double impurity: Double isLeaf: Boolean split/ |- feature: Int |- threshold: Double |- featureType: Int |- categories: Array[Double] leftNodeId: Integer rightNodeId: Integer infoGain: Double ~~~ 2. Some refactor of the implementation. Closes #4444. Author: Joseph K. Bradley <joseph@databricks.com> Author: Xiangrui Meng <meng@databricks.com> Closes #4493 from mengxr/SPARK-5597 and squashes the following commits: 75e3bb6 [Xiangrui Meng] fix style 2b0033d [Xiangrui Meng] update tree export schema and refactor the implementation 45873a2 [Joseph K. Bradley] org imports 1d4c264 [Joseph K. Bradley] Added save/load for tree ensembles dcdbf85 [Joseph K. Bradley] added save/load for decision tree but need to generalize it to ensembles
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala120
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala81
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala28
3 files changed, 195 insertions, 34 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 9347eaf922..7b1aed5ffe 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -29,8 +29,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
+import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
@@ -857,9 +859,32 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(topNode.leftNode.get.impurity === 0.0)
assert(topNode.rightNode.get.impurity === 0.0)
}
+
+ test("Node.subtreeIterator") {
+ val model = DecisionTreeSuite.createModel(Classification)
+ val nodeIds = model.topNode.subtreeIterator.map(_.id).toArray.sorted
+ assert(nodeIds === DecisionTreeSuite.createdModelNodeIds)
+ }
+
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ Array(Classification, Regression).foreach { algo =>
+ val model = DecisionTreeSuite.createModel(algo)
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = DecisionTreeModel.load(sc, path)
+ DecisionTreeSuite.checkEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
}
-object DecisionTreeSuite {
+object DecisionTreeSuite extends FunSuite {
def validateClassifier(
model: DecisionTreeModel,
@@ -979,4 +1004,95 @@ object DecisionTreeSuite {
arr
}
+ /** Create a leaf node with the given node ID */
+ private def createLeafNode(id: Int): Node = {
+ Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = true)
+ }
+
+ /**
+ * Create an internal node with the given node ID and feature type.
+ * Note: This does NOT set the child nodes.
+ */
+ private def createInternalNode(id: Int, featureType: FeatureType): Node = {
+ val node = Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = false)
+ featureType match {
+ case Continuous =>
+ node.split = Some(new Split(feature = 0, threshold = 0.5, Continuous,
+ categories = List.empty[Double]))
+ case Categorical =>
+ node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical,
+ categories = List(0.0, 1.0)))
+ }
+ // TODO: The information gain stats should be consistent with the same info stored in children.
+ node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2,
+ leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6)))
+ node
+ }
+
+ /**
+ * Create a tree model. This is deterministic and contains a variety of node and feature types.
+ */
+ private[tree] def createModel(algo: Algo): DecisionTreeModel = {
+ val topNode = createInternalNode(id = 1, Continuous)
+ val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical))
+ val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7))
+ topNode.leftNode = Some(node2)
+ topNode.rightNode = Some(node3)
+ node3.leftNode = Some(node6)
+ node3.rightNode = Some(node7)
+ new DecisionTreeModel(topNode, algo)
+ }
+
+ /** Sorted Node IDs matching the model returned by [[createModel()]] */
+ private val createdModelNodeIds = Array(1, 2, 3, 6, 7)
+
+ /**
+ * Check if the two trees are exactly the same.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+   * If the trees are not equal, this throws an exception whose message includes the
+   * debug-string representations of both trees.
+ */
+ private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
+ try {
+ assert(a.algo === b.algo)
+ checkEqual(a.topNode, b.topNode)
+ } catch {
+ case ex: Exception =>
+ throw new AssertionError("checkEqual failed since the two trees were not identical.\n" +
+ "TREE A:\n" + a.toDebugString + "\n" +
+ "TREE B:\n" + b.toDebugString + "\n", ex)
+ }
+ }
+
+ /**
+   * Check that the two nodes and all of their descendants are exactly the same;
+   * throws an AssertionError if they differ.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+ */
+ private def checkEqual(a: Node, b: Node): Unit = {
+ assert(a.id === b.id)
+ assert(a.predict === b.predict)
+ assert(a.impurity === b.impurity)
+ assert(a.isLeaf === b.isLeaf)
+ assert(a.split === b.split)
+ (a.stats, b.stats) match {
+      // TODO: Check other fields besides the information gain.
+ case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain)
+ case (None, None) =>
+ case _ => throw new AssertionError(
+ s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})")
+ }
+ (a.leftNode, b.leftNode) match {
+ case (Some(aNode), Some(bNode)) => checkEqual(aNode, bNode)
+ case (None, None) =>
+ case _ => throw new AssertionError("Only one instance has leftNode defined. " +
+ s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})")
+ }
+ (a.rightNode, b.rightNode) match {
+ case (Some(aNode: Node), Some(bNode: Node)) => checkEqual(aNode, bNode)
+ case (None, None) =>
+ case _ => throw new AssertionError("Only one instance has rightNode defined. " +
+ s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})")
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index e8341a5d0d..bde47606eb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -24,8 +24,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss}
-
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
/**
* Test suite for [[GradientBoostedTrees]].
@@ -35,32 +37,30 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
test("Regression with continuous features: SquaredError") {
GradientBoostedTreesSuite.testCombinations.foreach {
case (numIterations, learningRate, subsamplingRate) =>
- GradientBoostedTreesSuite.randomSeeds.foreach { randomSeed =>
- val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
-
- val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
- categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
- val boostingStrategy =
- new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
-
- val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
-
- assert(gbt.trees.size === numIterations)
- try {
- EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
- } catch {
- case e: java.lang.AssertionError =>
- println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
- s" subsamplingRate=$subsamplingRate")
- throw e
- }
-
- val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- val dt = DecisionTree.train(remappedInput, treeStrategy)
-
- // Make sure trees are the same.
- assert(gbt.trees.head.toString == dt.toString)
+ val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
+
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
+
+ val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+
+ assert(gbt.trees.size === numIterations)
+ try {
+ EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
+ } catch {
+ case e: java.lang.AssertionError =>
+ println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
+ s" subsamplingRate=$subsamplingRate")
+ throw e
}
+
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+ // Make sure trees are the same.
+ assert(gbt.trees.head.toString == dt.toString)
}
}
@@ -133,14 +133,37 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
BoostingStrategy.defaultParams(algo)
}
}
+
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(Regression)).toArray
+ val treeWeights = Array(0.1, 0.3, 1.1)
+
+ Array(Classification, Regression).foreach { algo =>
+ val model = new GradientBoostedTreesModel(algo, trees, treeWeights)
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = GradientBoostedTreesModel.load(sc, path)
+ assert(model.algo == sameModel.algo)
+ model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) =>
+ DecisionTreeSuite.checkEqual(treeA, treeB)
+ }
+ assert(model.treeWeights === sameModel.treeWeights)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
}
-object GradientBoostedTreesSuite {
+private object GradientBoostedTreesSuite {
// Combinations for estimators, learning rates and subsamplingRate
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
- val randomSeeds = Array(681283, 4398)
-
val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 55e963977b..ee3bc98486 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -27,8 +27,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
-import org.apache.spark.mllib.tree.model.Node
+import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
/**
* Test suite for [[RandomForest]].
@@ -212,6 +214,26 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
assert(rf1.toDebugString != rf2.toDebugString)
}
-}
-
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ Array(Classification, Regression).foreach { algo =>
+ val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(algo)).toArray
+ val model = new RandomForestModel(algo, trees)
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = RandomForestModel.load(sc, path)
+ assert(model.algo == sameModel.algo)
+ model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) =>
+ DecisionTreeSuite.checkEqual(treeA, treeB)
+ }
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
+}