author    Joseph K. Bradley <joseph@databricks.com>  2015-02-09 22:09:07 -0800
committer Xiangrui Meng <meng@databricks.com>  2015-02-09 22:09:07 -0800
commit    ef2f55b97f58fa06acb30e9e0172fb66fba383bc
tree      3e3cfab15830e2f85db891f2adb12a3c9d7a09c1 /mllib/src
parent    bd0b5ea708aa5b84adb67c039ec52408289718bb
download  spark-ef2f55b97f58fa06acb30e9e0172fb66fba383bc.tar.gz
          spark-ef2f55b97f58fa06acb30e9e0172fb66fba383bc.tar.bz2
          spark-ef2f55b97f58fa06acb30e9e0172fb66fba383bc.zip
[SPARK-5597][MLLIB] save/load for decision trees and ensembles
This is based on #4444 from jkbradley, with the following changes:

1. Node schema updated to:
~~~
treeId: Int
nodeId: Int
predict/
  |- predict: Double
  |- prob: Double
impurity: Double
isLeaf: Boolean
split/
  |- feature: Int
  |- threshold: Double
  |- featureType: Int
  |- categories: Array[Double]
leftNodeId: Integer
rightNodeId: Integer
infoGain: Double
~~~
2. Some refactoring of the implementation.

Closes #4444.

Author: Joseph K. Bradley <joseph@databricks.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #4493 from mengxr/SPARK-5597 and squashes the following commits:

75e3bb6 [Xiangrui Meng] fix style
2b0033d [Xiangrui Meng] update tree export schema and refactor the implementation
45873a2 [Joseph K. Bradley] org imports
1d4c264 [Joseph K. Bradley] Added save/load for tree ensembles
dcdbf85 [Joseph K. Bradley] added save/load for decision tree but need to generalize it to ensembles
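In short: DecisionTreeModel (and the ensemble models below) gain a save(sc, path) method, with a matching load(sc, path) on each companion object. A minimal round-trip sketch, assuming an existing SparkContext and placeholder paths; the training call and data file are illustrative, not part of this patch:

~~~scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils

val sc: SparkContext = ??? // an existing SparkContext (placeholder)

// Train a small classifier; the data path is a placeholder.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val model = DecisionTree.trainClassifier(data, 2, Map.empty[Int, Int], "gini", 3, 32)

// Persist the model (JSON metadata + Parquet node data), then load it back.
model.save(sc, "/tmp/myDecisionTreeModel")
val sameModel = DecisionTreeModel.load(sc, "/tmp/myDecisionTreeModel")
assert(sameModel.numNodes == model.numNodes)
~~~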
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala      197
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala     4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala                      5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala                   7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala      157
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala             120
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala      81
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala              28
8 files changed, 561 insertions, 38 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index a25e625a40..89ecf3773d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -17,11 +17,17 @@
package org.apache.spark.mllib.tree.model
+import scala.collection.mutable
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType}
import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
/**
* :: Experimental ::
@@ -31,7 +37,7 @@ import org.apache.spark.rdd.RDD
* @param algo algorithm type -- classification or regression
*/
@Experimental
-class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {
+class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable {
/**
* Predict values for a single data point using the model trained.
@@ -98,4 +104,193 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
header + topNode.subtreeToString(2)
}
+ override def save(sc: SparkContext, path: String): Unit = {
+ DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object DecisionTreeModel extends Loader[DecisionTreeModel] {
+
+ private[tree] object SaveLoadV1_0 {
+
+ def thisFormatVersion = "1.0"
+
+ // Hard-code class name string in case it changes in the future
+ def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel"
+
+ case class PredictData(predict: Double, prob: Double) {
+ def toPredict: Predict = new Predict(predict, prob)
+ }
+
+ object PredictData {
+ def apply(p: Predict): PredictData = PredictData(p.predict, p.prob)
+
+ def apply(r: Row): PredictData = PredictData(r.getDouble(0), r.getDouble(1))
+ }
+
+ case class SplitData(
+ feature: Int,
+ threshold: Double,
+ featureType: Int,
+ categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed
+ def toSplit: Split = {
+ new Split(feature, threshold, FeatureType(featureType), categories.toList)
+ }
+ }
+
+ object SplitData {
+ def apply(s: Split): SplitData = {
+ SplitData(s.feature, s.threshold, s.featureType.id, s.categories)
+ }
+
+ def apply(r: Row): SplitData = {
+ SplitData(r.getInt(0), r.getDouble(1), r.getInt(2), r.getAs[Seq[Double]](3))
+ }
+ }
+
+ /** Model data for model import/export */
+ case class NodeData(
+ treeId: Int,
+ nodeId: Int,
+ predict: PredictData,
+ impurity: Double,
+ isLeaf: Boolean,
+ split: Option[SplitData],
+ leftNodeId: Option[Int],
+ rightNodeId: Option[Int],
+ infoGain: Option[Double])
+
+ object NodeData {
+ def apply(treeId: Int, n: Node): NodeData = {
+ NodeData(treeId, n.id, PredictData(n.predict), n.impurity, n.isLeaf,
+ n.split.map(SplitData.apply), n.leftNode.map(_.id), n.rightNode.map(_.id),
+ n.stats.map(_.gain))
+ }
+
+ def apply(r: Row): NodeData = {
+ val split = if (r.isNullAt(5)) None else Some(SplitData(r.getStruct(5)))
+ val leftNodeId = if (r.isNullAt(6)) None else Some(r.getInt(6))
+ val rightNodeId = if (r.isNullAt(7)) None else Some(r.getInt(7))
+ val infoGain = if (r.isNullAt(8)) None else Some(r.getDouble(8))
+ NodeData(r.getInt(0), r.getInt(1), PredictData(r.getStruct(2)), r.getDouble(3),
+ r.getBoolean(4), split, leftNodeId, rightNodeId, infoGain)
+ }
+ }
+
+ def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Create JSON metadata.
+ val metadataRDD = sc.parallelize(
+ Seq((thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)), 1)
+ .toDataFrame("class", "version", "algo", "numNodes")
+ metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val nodes = model.topNode.subtreeIterator.toSeq
+ val dataRDD: DataFrame = sc.parallelize(nodes)
+ .map(NodeData.apply(0, _))
+ .toDataFrame
+ dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
+ val datapath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ // Load Parquet data.
+ val dataRDD = sqlContext.parquetFile(datapath)
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ Loader.checkSchema[NodeData](dataRDD.schema)
+ val nodes = dataRDD.map(NodeData.apply)
+ // Build node data into a tree.
+ val trees = constructTrees(nodes)
+ assert(trees.size == 1,
+ "Decision tree should contain exactly one tree but got ${trees.size} trees.")
+ val model = new DecisionTreeModel(trees(0), Algo.fromString(algo))
+ assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $datapath." +
+ s" Expected $numNodes nodes but found ${model.numNodes}")
+ model
+ }
+
+ def constructTrees(nodes: RDD[NodeData]): Array[Node] = {
+ val trees = nodes
+ .groupBy(_.treeId)
+ .mapValues(_.toArray)
+ .collect()
+ .map { case (treeId, data) =>
+ (treeId, constructTree(data))
+ }.sortBy(_._1)
+ val numTrees = trees.size
+ val treeIndices = trees.map(_._1).toSeq
+ assert(treeIndices == (0 until numTrees),
+ s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.")
+ trees.map(_._2)
+ }
+
+ /**
+ * Given a list of nodes from a tree, construct the tree.
+ * @param data array of all node data in a tree.
+ */
+ def constructTree(data: Array[NodeData]): Node = {
+ val dataMap: Map[Int, NodeData] = data.map(n => n.nodeId -> n).toMap
+ assert(dataMap.contains(1),
+ s"DecisionTree missing root node (id = 1).")
+ constructNode(1, dataMap, mutable.Map.empty)
+ }
+
+ /**
+ * Builds a node from the node data map and adds new nodes to the input nodes map.
+ */
+ private def constructNode(
+ id: Int,
+ dataMap: Map[Int, NodeData],
+ nodes: mutable.Map[Int, Node]): Node = {
+ if (nodes.contains(id)) {
+ return nodes(id)
+ }
+ val data = dataMap(id)
+ val node =
+ if (data.isLeaf) {
+ Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf)
+ } else {
+ val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes)
+ val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes)
+ val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity,
+ rightNode.impurity, leftNode.predict, rightNode.predict)
+ new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf,
+ data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats))
+ }
+ nodes += node.id -> node
+ node
+ }
+ }
+
+ override def load(sc: SparkContext, path: String): DecisionTreeModel = {
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ val (algo: String, numNodes: Int) = try {
+ val algo_numNodes = metadata.select("algo", "numNodes").collect()
+ assert(algo_numNodes.length == 1)
+ algo_numNodes(0) match {
+ case Row(a: String, n: Int) => (a, n)
+ }
+ } catch {
+ // Catch both Error and Exception since the checks above can throw either.
+ case e: Throwable =>
+ throw new Exception(
+ s"Unable to load DecisionTreeModel metadata from: ${Loader.metadataPath(path)}."
+ + s" Error message: ${e.getMessage}")
+ }
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ SaveLoadV1_0.load(sc, path, algo, numNodes)
+ case _ => throw new Exception(
+ s"DecisionTreeModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
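For reference, SaveLoadV1_0.save writes two outputs under path: a one-row JSON metadata file (class, version, algo, numNodes) and a Parquet file of NodeData rows. A hedged inspection sketch; the metadata/data subdirectory names come from Loader.metadataPath and Loader.dataPath, which are not shown in this diff, so they are an assumption:

~~~scala
// Assumption: Loader.metadataPath(path) == s"$path/metadata" and
// Loader.dataPath(path) == s"$path/data"; both are inferred, not shown here.
val path = "/tmp/myDecisionTreeModel" // placeholder; sc is an existing SparkContext

// The metadata is plain JSON text, one object per line.
sc.textFile(s"$path/metadata").collect().foreach(println)

// The node data is Parquet; each row is one NodeData record.
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
sqlContext.parquetFile(s"$path/data").take(5).foreach(println)
~~~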
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 9a50ecb550..80990aa9a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -49,7 +49,9 @@ class InformationGainStats(
gain == other.gain &&
impurity == other.impurity &&
leftImpurity == other.leftImpurity &&
- rightImpurity == other.rightImpurity
+ rightImpurity == other.rightImpurity &&
+ leftPredict == other.leftPredict &&
+ rightPredict == other.rightPredict
}
case _ => false
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 2179da8dbe..d961081d18 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -166,6 +166,11 @@ class Node (
}
}
+ /** Returns an iterator that traverses (DFS, left to right) the subtree of this node. */
+ private[tree] def subtreeIterator: Iterator[Node] = {
+ Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++
+ rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty)
+ }
}
private[tree] object Node {
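The iterator above is a preorder (node, left subtree, right subtree) traversal built by concatenating iterators. A self-contained sketch of the same pattern on a toy tree; BinaryNode is invented here for illustration:

~~~scala
// Toy tree mirroring Node.subtreeIterator: left-to-right depth-first order.
case class BinaryNode(
    id: Int,
    left: Option[BinaryNode] = None,
    right: Option[BinaryNode] = None) {
  def subtreeIterator: Iterator[BinaryNode] =
    Iterator.single(this) ++
      left.map(_.subtreeIterator).getOrElse(Iterator.empty) ++
      right.map(_.subtreeIterator).getOrElse(Iterator.empty)
}

val root = BinaryNode(1,
  left = Some(BinaryNode(2)),
  right = Some(BinaryNode(3, Some(BinaryNode(6)), Some(BinaryNode(7)))))
assert(root.subtreeIterator.map(_.id).toList == List(1, 2, 3, 6, 7))
~~~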
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index 004838ee5b..ad4c0dbbfb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -32,4 +32,11 @@ class Predict(
override def toString = {
"predict = %f, prob = %f".format(predict, prob)
}
+
+ override def equals(other: Any): Boolean = {
+ other match {
+ case p: Predict => predict == p.predict && prob == p.prob
+ case _ => false
+ }
+ }
}
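One caveat: this adds equals without a matching hashCode, so two equal Predict instances may hash to different buckets in hash-based collections. A hedged sketch of a consistent override, not part of this patch (the constructor signature is assumed from context):

~~~scala
// Hypothetical sketch, NOT part of this patch: pair the new equals with a
// consistent hashCode so Predict works correctly as a hash-map key.
class Predict(val predict: Double, val prob: Double = 0.0) {
  override def equals(other: Any): Boolean = other match {
    case p: Predict => predict == p.predict && prob == p.prob
    case _ => false
  }
  // Combine both fields so equal instances always hash to the same value.
  override def hashCode: Int = 31 * predict.hashCode + prob.hashCode
}
~~~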
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 22997110de..23bd46baab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -21,12 +21,17 @@ import scala.collection.mutable
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
/**
* :: Experimental ::
@@ -38,9 +43,42 @@ import org.apache.spark.rdd.RDD
@Experimental
class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0),
- combiningStrategy = if (algo == Classification) Vote else Average) {
+ combiningStrategy = if (algo == Classification) Vote else Average)
+ with Saveable {
require(trees.forall(_.algo == algo))
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
+ RandomForestModel.SaveLoadV1_0.thisClassName)
+ }
+
+ override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+}
+
+object RandomForestModel extends Loader[RandomForestModel] {
+
+ override def load(sc: SparkContext, path: String): RandomForestModel = {
+ val (loadedClassName, version, metadataRDD) = Loader.loadMetadata(sc, path)
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(metadataRDD, path)
+ assert(metadata.treeWeights.forall(_ == 1.0))
+ val trees =
+ TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
+ new RandomForestModel(Algo.fromString(metadata.algo), trees)
+ case _ => throw new Exception(s"RandomForestModel.load did not recognize model" +
+ s" with (className, format version): ($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+
+ private object SaveLoadV1_0 {
+ // Hard-code class name string in case it changes in the future
+ def thisClassName = "org.apache.spark.mllib.tree.model.RandomForestModel"
+ }
+
}
/**
@@ -56,9 +94,42 @@ class GradientBoostedTreesModel(
override val algo: Algo,
override val trees: Array[DecisionTreeModel],
override val treeWeights: Array[Double])
- extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) {
+ extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
+ with Saveable {
require(trees.size == treeWeights.size)
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
+ GradientBoostedTreesModel.SaveLoadV1_0.thisClassName)
+ }
+
+ override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+}
+
+object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
+
+ override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
+ val (loadedClassName, version, metadataRDD) = Loader.loadMetadata(sc, path)
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(metadataRDD, path)
+ assert(metadata.combiningStrategy == Sum.toString)
+ val trees =
+ TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
+ new GradientBoostedTreesModel(Algo.fromString(metadata.algo), trees, metadata.treeWeights)
+ case _ => throw new Exception(s"GradientBoostedTreesModel.load did not recognize model" +
+ s" with (className, format version): ($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+
+ private object SaveLoadV1_0 {
+ // Hard-code class name string in case it changes in the future
+ def thisClassName = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
+ }
+
}
/**
@@ -176,3 +247,85 @@ private[tree] sealed class TreeEnsembleModel(
*/
def totalNumNodes: Int = trees.map(_.numNodes).sum
}
+
+private[tree] object TreeEnsembleModel {
+
+ object SaveLoadV1_0 {
+
+ import DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees}
+
+ def thisFormatVersion = "1.0"
+
+ case class Metadata(
+ algo: String,
+ treeAlgo: String,
+ combiningStrategy: String,
+ treeWeights: Array[Double])
+
+ /**
+ * Model data for model import/export.
+ * We have to duplicate NodeData here since Spark SQL does not yet support extracting subfields
+ * of nested fields; once that is possible, we can use something like:
+ * case class EnsembleNodeData(treeId: Int, node: NodeData),
+ * where NodeData is from DecisionTreeModel.
+ */
+ case class EnsembleNodeData(treeId: Int, node: NodeData)
+
+ def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Create JSON metadata.
+ val metadata = Metadata(model.algo.toString, model.trees(0).algo.toString,
+ model.combiningStrategy.toString, model.treeWeights)
+ val metadataRDD = sc.parallelize(Seq((className, thisFormatVersion, metadata)), 1)
+ .toDataFrame("class", "version", "metadata")
+ metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) =>
+ tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node))
+ }.toDataFrame
+ dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ /**
+ * Read metadata from the loaded metadata DataFrame.
+ * @param path Path for loading data, used for debug messages.
+ */
+ def readMetadata(metadata: DataFrame, path: String): Metadata = {
+ try {
+ // We rely on the try-catch for schema checking rather than creating a schema just for this.
+ val metadataArray = metadata.select("metadata.algo", "metadata.treeAlgo",
+ "metadata.combiningStrategy", "metadata.treeWeights").collect()
+ assert(metadataArray.size == 1)
+ Metadata(metadataArray(0).getString(0), metadataArray(0).getString(1),
+ metadataArray(0).getString(2), metadataArray(0).getAs[Seq[Double]](3).toArray)
+ } catch {
+ // Catch both Error and Exception since the checks above can throw either.
+ case e: Throwable =>
+ throw new Exception(
+ s"Unable to load TreeEnsembleModel metadata from: ${Loader.metadataPath(path)}."
+ + s" Error message: ${e.getMessage}")
+ }
+ }
+
+ /**
+ * Load trees for an ensemble, and return them in order.
+ * @param path path to load the model from
+ * @param treeAlgo Algorithm for individual trees (which may differ from the ensemble's
+ * algorithm).
+ */
+ def loadTrees(
+ sc: SparkContext,
+ path: String,
+ treeAlgo: String): Array[DecisionTreeModel] = {
+ val datapath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ val nodes = sqlContext.parquetFile(datapath).map(NodeData.apply)
+ val trees = constructTrees(nodes)
+ trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo)))
+ }
+ }
+
+}
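Ensemble persistence follows the same pattern as single trees: shared JSON metadata plus one Parquet file holding every tree's nodes keyed by treeId. A round-trip sketch for a random forest, assuming an existing SparkContext sc; the paths and the RandomForest.trainClassifier parameters are illustrative assumptions, not part of this diff:

~~~scala
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils

// Train a small forest; the data path and parameters are placeholders.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val forest = RandomForest.trainClassifier(data, 2, Map.empty[Int, Int],
  3, "auto", "gini", 4, 32, 42)

// Save and reload; all trees round-trip through one Parquet file.
forest.save(sc, "/tmp/myForestModel")
val sameForest = RandomForestModel.load(sc, "/tmp/myForestModel")
assert(sameForest.totalNumNodes == forest.totalNumNodes)
~~~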
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 9347eaf922..7b1aed5ffe 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -29,8 +29,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
+import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
@@ -857,9 +859,32 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(topNode.leftNode.get.impurity === 0.0)
assert(topNode.rightNode.get.impurity === 0.0)
}
+
+ test("Node.subtreeIterator") {
+ val model = DecisionTreeSuite.createModel(Classification)
+ val nodeIds = model.topNode.subtreeIterator.map(_.id).toArray.sorted
+ assert(nodeIds === DecisionTreeSuite.createdModelNodeIds)
+ }
+
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ Array(Classification, Regression).foreach { algo =>
+ val model = DecisionTreeSuite.createModel(algo)
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = DecisionTreeModel.load(sc, path)
+ DecisionTreeSuite.checkEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
}
-object DecisionTreeSuite {
+object DecisionTreeSuite extends FunSuite {
def validateClassifier(
model: DecisionTreeModel,
@@ -979,4 +1004,95 @@ object DecisionTreeSuite {
arr
}
+ /** Create a leaf node with the given node ID */
+ private def createLeafNode(id: Int): Node = {
+ Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = true)
+ }
+
+ /**
+ * Create an internal node with the given node ID and feature type.
+ * Note: This does NOT set the child nodes.
+ */
+ private def createInternalNode(id: Int, featureType: FeatureType): Node = {
+ val node = Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = false)
+ featureType match {
+ case Continuous =>
+ node.split = Some(new Split(feature = 0, threshold = 0.5, Continuous,
+ categories = List.empty[Double]))
+ case Categorical =>
+ node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical,
+ categories = List(0.0, 1.0)))
+ }
+ // TODO: The information gain stats should be consistent with the same info stored in children.
+ node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2,
+ leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6)))
+ node
+ }
+
+ /**
+ * Create a tree model. This is deterministic and contains a variety of node and feature types.
+ */
+ private[tree] def createModel(algo: Algo): DecisionTreeModel = {
+ val topNode = createInternalNode(id = 1, Continuous)
+ val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical))
+ val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7))
+ topNode.leftNode = Some(node2)
+ topNode.rightNode = Some(node3)
+ node3.leftNode = Some(node6)
+ node3.rightNode = Some(node7)
+ new DecisionTreeModel(topNode, algo)
+ }
+
+ /** Sorted Node IDs matching the model returned by [[createModel()]] */
+ private val createdModelNodeIds = Array(1, 2, 3, 6, 7)
+
+ /**
+ * Check if the two trees are exactly the same.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+ * If the trees are not equal, this throws an AssertionError whose message includes both trees' debug strings.
+ */
+ private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
+ try {
+ assert(a.algo === b.algo)
+ checkEqual(a.topNode, b.topNode)
+ } catch {
+ case ex: Exception =>
+ throw new AssertionError("checkEqual failed since the two trees were not identical.\n" +
+ "TREE A:\n" + a.toDebugString + "\n" +
+ "TREE B:\n" + b.toDebugString + "\n", ex)
+ }
+ }
+
+ /**
+ * Check that the two nodes and their descendants are exactly the same; throws an AssertionError otherwise.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+ */
+ private def checkEqual(a: Node, b: Node): Unit = {
+ assert(a.id === b.id)
+ assert(a.predict === b.predict)
+ assert(a.impurity === b.impurity)
+ assert(a.isLeaf === b.isLeaf)
+ assert(a.split === b.split)
+ (a.stats, b.stats) match {
+ // TODO: Check other fields besides the information gain.
+ case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain)
+ case (None, None) =>
+ case _ => throw new AssertionError(
+ s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})")
+ }
+ (a.leftNode, b.leftNode) match {
+ case (Some(aNode), Some(bNode)) => checkEqual(aNode, bNode)
+ case (None, None) =>
+ case _ => throw new AssertionError("Only one instance has leftNode defined. " +
+ s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})")
+ }
+ (a.rightNode, b.rightNode) match {
+ case (Some(aNode: Node), Some(bNode: Node)) => checkEqual(aNode, bNode)
+ case (None, None) =>
+ case _ => throw new AssertionError("Only one instance has rightNode defined. " +
+ s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})")
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index e8341a5d0d..bde47606eb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -24,8 +24,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss}
-
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
/**
* Test suite for [[GradientBoostedTrees]].
@@ -35,32 +37,30 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
test("Regression with continuous features: SquaredError") {
GradientBoostedTreesSuite.testCombinations.foreach {
case (numIterations, learningRate, subsamplingRate) =>
- GradientBoostedTreesSuite.randomSeeds.foreach { randomSeed =>
- val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
-
- val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
- categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
- val boostingStrategy =
- new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
-
- val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
-
- assert(gbt.trees.size === numIterations)
- try {
- EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
- } catch {
- case e: java.lang.AssertionError =>
- println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
- s" subsamplingRate=$subsamplingRate")
- throw e
- }
-
- val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- val dt = DecisionTree.train(remappedInput, treeStrategy)
-
- // Make sure trees are the same.
- assert(gbt.trees.head.toString == dt.toString)
+ val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
+
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
+
+ val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+
+ assert(gbt.trees.size === numIterations)
+ try {
+ EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
+ } catch {
+ case e: java.lang.AssertionError =>
+ println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
+ s" subsamplingRate=$subsamplingRate")
+ throw e
}
+
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+ // Make sure trees are the same.
+ assert(gbt.trees.head.toString == dt.toString)
}
}
@@ -133,14 +133,37 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
BoostingStrategy.defaultParams(algo)
}
}
+
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(Regression)).toArray
+ val treeWeights = Array(0.1, 0.3, 1.1)
+
+ Array(Classification, Regression).foreach { algo =>
+ val model = new GradientBoostedTreesModel(algo, trees, treeWeights)
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = GradientBoostedTreesModel.load(sc, path)
+ assert(model.algo == sameModel.algo)
+ model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) =>
+ DecisionTreeSuite.checkEqual(treeA, treeB)
+ }
+ assert(model.treeWeights === sameModel.treeWeights)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
}
-object GradientBoostedTreesSuite {
+private object GradientBoostedTreesSuite {
// Combinations for estimators, learning rates and subsamplingRate
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
- val randomSeeds = Array(681283, 4398)
-
val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 55e963977b..ee3bc98486 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -27,8 +27,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
-import org.apache.spark.mllib.tree.model.Node
+import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
/**
* Test suite for [[RandomForest]].
@@ -212,6 +214,26 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
assert(rf1.toDebugString != rf2.toDebugString)
}
-}
-
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ Array(Classification, Regression).foreach { algo =>
+ val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(algo)).toArray
+ val model = new RandomForestModel(algo, trees)
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = RandomForestModel.load(sc, path)
+ assert(model.algo == sameModel.algo)
+ model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) =>
+ DecisionTreeSuite.checkEqual(treeA, treeB)
+ }
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
+}