aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2016-04-04 10:24:02 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-04 10:24:02 -0700
commit89f3befab6c150f87de2fb91b50ea8b414c69095 (patch)
tree5b6e77a97a6ca8247fec9f750640d80353c7ef1d
parent745425332f41e2ae94649f9d1ad675243f36f743 (diff)
downloadspark-89f3befab6c150f87de2fb91b50ea8b414c69095.tar.gz
spark-89f3befab6c150f87de2fb91b50ea8b414c69095.tar.bz2
spark-89f3befab6c150f87de2fb91b50ea8b414c69095.zip
[SPARK-13784][ML] Persistence for RandomForestClassifier, RandomForestRegressor
## What changes were proposed in this pull request? **Main change**: Added save/load for RandomForestClassifier, RandomForestRegressor (implementation details below) Modified numTrees method (*deprecation*) * Goal: Use default implementations of unit tests which assume Estimators and Models share the same set of Params. * What this PR does: Moves method numTrees outside of trait TreeEnsembleModel. Adds it to GBT and RF Models. Deprecates it in RF Models in favor of new method getNumTrees. In Spark 2.1, we can have RF Models include Param numTrees. Minor items * Fixes bugs in GBTClassificationModel, GBTRegressionModel fromOld methods where they assign the wrong old UID. **Implementation details** * Split DecisionTreeModelReadWrite.loadTreeNodes into 2 methods in order to reuse some code for ensembles. * Added EnsembleModelReadWrite object with save/load implementations usable for RFs and GBTs * These store all trees' nodes in a single DataFrame, and all trees' metadata in a second DataFrame. * Split trait RandomForestParams into parts in order to add more Estimator Params to RF models * Split DefaultParamsWriter.saveMetadata into two methods to allow ensembles to store sub-models' metadata in a single DataFrame. Same for DefaultParamsReader.loadMetadata ## How was this patch tested? Adds standard unit tests for RF save/load Author: Joseph K. Bradley <joseph@databricks.com> Author: GayathriMurali <gayathri.m.softie@gmail.com> Closes #12118 from jkbradley/GayathriMurali-SPARK-13784.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala7
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala100
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala7
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala96
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala131
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala66
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala42
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala40
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala38
9 files changed, 424 insertions, 103 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 48ce051d0a..bfefaf1a1a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -192,7 +192,7 @@ final class GBTClassificationModel private[ml](
extends PredictionModel[Vector, GBTClassificationModel]
with TreeEnsembleModel with Serializable {
- require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
@@ -227,6 +227,9 @@ final class GBTClassificationModel private[ml](
if (prediction > 0.0) 1.0 else 0.0
}
+ /** Number of trees in ensemble */
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): GBTClassificationModel = {
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
@@ -272,6 +275,6 @@ private[ml] object GBTClassificationModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
- new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
+ new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 82fa05a604..2ad893f4fa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -17,11 +17,15 @@
package org.apache.spark.ml.classification
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -43,7 +47,7 @@ import org.apache.spark.sql.functions._
final class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
- with RandomForestParams with TreeClassifierParams {
+ with RandomForestClassifierParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfc"))
@@ -120,7 +124,7 @@ final class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0")
@Experimental
-object RandomForestClassifier {
+object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] {
/** Accessor for supported impurity settings: entropy, gini */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
@@ -129,6 +133,9 @@ object RandomForestClassifier {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestClassifier = super.load(path)
}
/**
@@ -136,8 +143,9 @@ object RandomForestClassifier {
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification.
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
+ *
* @param _trees Decision trees in the ensemble.
- * Warning: These have null parents.
+ * Warning: These have null parents.
*/
@Since("1.4.0")
@Experimental
@@ -147,12 +155,14 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
- with TreeEnsembleModel with Serializable {
+ with RandomForestClassificationModelParams with TreeEnsembleModel with MLWritable
+ with Serializable {
- require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
/**
* Construct a random forest classification model, with all trees weighted equally.
+ *
* @param trees Component trees
*/
private[ml] def this(
@@ -165,7 +175,7 @@ final class RandomForestClassificationModel private[ml] (
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
// Note: We may add support for weights (based on tree performance) later on.
- private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
@@ -208,6 +218,15 @@ final class RandomForestClassificationModel private[ml] (
}
}
+ /**
+ * Number of trees in ensemble
+ *
+ * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
+ */
+ // TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
+ @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestClassificationModel = {
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
@@ -216,7 +235,7 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.4.0")
override def toString: String = {
- s"RandomForestClassificationModel (uid=$uid) with $numTrees trees"
+ s"RandomForestClassificationModel (uid=$uid) with $getNumTrees trees"
}
/**
@@ -236,12 +255,69 @@ final class RandomForestClassificationModel private[ml] (
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new RandomForestClassificationModel.RandomForestClassificationModelWriter(this)
}
-private[ml] object RandomForestClassificationModel {
+@Since("2.0.0")
+object RandomForestClassificationModel extends MLReadable[RandomForestClassificationModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[RandomForestClassificationModel] =
+ new RandomForestClassificationModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestClassificationModel = super.load(path)
+
+ private[RandomForestClassificationModel]
+ class RandomForestClassificationModelWriter(instance: RandomForestClassificationModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ // Note: numTrees is not currently used, but could be nice to store for fast querying.
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numClasses" -> instance.numClasses,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class RandomForestClassificationModelReader
+ extends MLReader[RandomForestClassificationModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RandomForestClassificationModel].getName
+ private val treeClassName = classOf[DecisionTreeClassificationModel].getName
+
+ override def load(path: String): RandomForestClassificationModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeClassificationModel] = treesData.map {
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+ require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
categoricalFeatures: Map[Int, Int],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 8fca35da51..02e124a1c0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -179,7 +179,7 @@ final class GBTRegressionModel private[ml](
extends PredictionModel[Vector, GBTRegressionModel]
with TreeEnsembleModel with Serializable {
- require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
@@ -213,6 +213,9 @@ final class GBTRegressionModel private[ml](
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
}
+ /** Number of trees in ensemble */
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): GBTRegressionModel = {
copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
@@ -258,6 +261,6 @@ private[ml] object GBTRegressionModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
- new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
+ new GBTRegressionModel(uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 5b3f3a1f5d..ba56b5cd3f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -17,12 +17,16 @@
package org.apache.spark.ml.regression
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -41,7 +45,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
- with RandomForestParams with TreeRegressorParams {
+ with RandomForestRegressorParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfr"))
@@ -108,7 +112,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val
@Since("1.4.0")
@Experimental
-object RandomForestRegressor {
+object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{
/** Accessor for supported impurity settings: variance */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
@@ -117,12 +121,17 @@ object RandomForestRegressor {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressor = super.load(path)
+
}
/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
* It supports both continuous and categorical features.
+ *
* @param _trees Decision trees in the ensemble.
* @param numFeatures Number of features used by this model
*/
@@ -133,12 +142,13 @@ final class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
- with TreeEnsembleModel with Serializable {
+ with RandomForestRegressionModelParams with TreeEnsembleModel with MLWritable with Serializable {
- require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
/**
* Construct a random forest regression model, with all trees weighted equally.
+ *
* @param trees Component trees
*/
private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
@@ -148,7 +158,7 @@ final class RandomForestRegressionModel private[ml] (
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
// Note: We may add support for weights (based on tree performance) later on.
- private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
@@ -165,9 +175,17 @@ final class RandomForestRegressionModel private[ml] (
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
- _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
+ _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
}
+ /**
+ * Number of trees in ensemble
+ * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
+ */
+ // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams
+ @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestRegressionModel = {
copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
@@ -175,7 +193,7 @@ final class RandomForestRegressionModel private[ml] (
@Since("1.4.0")
override def toString: String = {
- s"RandomForestRegressionModel (uid=$uid) with $numTrees trees"
+ s"RandomForestRegressionModel (uid=$uid) with $getNumTrees trees"
}
/**
@@ -195,12 +213,63 @@ final class RandomForestRegressionModel private[ml] (
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new RandomForestRegressionModel.RandomForestRegressionModelWriter(this)
}
-private[ml] object RandomForestRegressionModel {
+@Since("2.0.0")
+object RandomForestRegressionModel extends MLReadable[RandomForestRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[RandomForestRegressionModel] = new RandomForestRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressionModel = super.load(path)
+
+ private[RandomForestRegressionModel]
+ class RandomForestRegressionModelWriter(instance: RandomForestRegressionModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class RandomForestRegressionModelReader extends MLReader[RandomForestRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RandomForestRegressionModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): RandomForestRegressionModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+ require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestRegressor,
categoricalFeatures: Map[Int, Int],
@@ -211,6 +280,7 @@ private[ml] object RandomForestRegressionModel {
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestRegressionModel(parent.uid, newTrees, numFeatures)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfr")
+ new RandomForestRegressionModel(uid, newTrees, numFeatures)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 8ea767b2b3..48b8fd19ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -21,12 +21,15 @@ import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.ml.param.Param
-import org.apache.spark.ml.util.DefaultParamsReader
+import org.apache.spark.ml.param.{Param, Params}
+import org.apache.spark.ml.tree.DecisionTreeModelReadWrite.NodeData
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Dataset, SQLContext}
import org.apache.spark.util.collection.OpenHashMap
/**
@@ -88,6 +91,11 @@ private[ml] trait TreeEnsembleModel {
/** Trees in this ensemble. Warning: These have null parent Estimators. */
def trees: Array[DecisionTreeModel]
+ /**
+ * Number of trees in ensemble
+ */
+ val getNumTrees: Int = trees.length
+
/** Weights for each tree, zippable with [[trees]] */
def treeWeights: Array[Double]
@@ -98,7 +106,7 @@ private[ml] trait TreeEnsembleModel {
/** Summary of the model */
override def toString: String = {
// Implementing classes should generally override this method to be more descriptive.
- s"TreeEnsembleModel with $numTrees trees"
+ s"TreeEnsembleModel with ${trees.length} trees"
}
/** Full description of model */
@@ -109,9 +117,6 @@ private[ml] trait TreeEnsembleModel {
}.fold("")(_ + _)
}
- /** Number of trees in ensemble */
- val numTrees: Int = trees.length
-
/** Total number of nodes, summed over all trees in the ensemble. */
lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
}
@@ -316,6 +321,10 @@ private[ml] object DecisionTreeModelReadWrite {
}
}
+ /**
+ * Load a decision tree from a file.
+ * @return Root node of reconstructed tree
+ */
def loadTreeNodes(
path: String,
metadata: DefaultParamsReader.Metadata,
@@ -331,9 +340,18 @@ private[ml] object DecisionTreeModelReadWrite {
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).as[NodeData]
+ buildTreeFromNodes(data.collect(), impurityType)
+ }
+ /**
+ * Given all data for all nodes in a tree, rebuild the tree.
+ * @param data Unsorted node data
+ * @param impurityType Impurity type for this tree
+ * @return Root node of reconstructed tree
+ */
+ def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
// Load all nodes, sorted by ID.
- val nodes: Array[NodeData] = data.collect().sortBy(_.id)
+ val nodes = data.sortBy(_.id)
// Sanity checks; could remove
assert(nodes.head.id == 0, s"Decision Tree load failed. Expected smallest node ID to be 0," +
s" but found ${nodes.head.id}")
@@ -358,3 +376,100 @@ private[ml] object DecisionTreeModelReadWrite {
finalNodes.head
}
}
+
+private[ml] object EnsembleModelReadWrite {
+
+ /**
+ * Helper method for saving a tree ensemble to disk.
+ *
+ * @param instance Tree ensemble model
+ * @param path Path to which to save the ensemble model.
+ * @param extraMetadata Metadata such as numFeatures, numClasses, numTrees.
+ */
+ def saveImpl[M <: Params with TreeEnsembleModel](
+ instance: M,
+ path: String,
+ sql: SQLContext,
+ extraMetadata: JObject): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata))
+ val treesMetadataJson: Array[(Int, String)] = instance.trees.zipWithIndex.map {
+ case (tree, treeID) =>
+ treeID -> DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext)
+ }
+ val treesMetadataPath = new Path(path, "treesMetadata").toString
+ sql.createDataFrame(treesMetadataJson).toDF("treeID", "metadata")
+ .write.parquet(treesMetadataPath)
+ val dataPath = new Path(path, "data").toString
+ val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap {
+ case (tree, treeID) => EnsembleNodeData.build(tree, treeID)
+ }
+ sql.createDataFrame(nodeDataRDD).write.parquet(dataPath)
+ }
+
+ /**
+ * Helper method for loading a tree ensemble from disk.
+ * This reconstructs all trees, returning the root nodes.
+ * @param path Path given to [[saveImpl()]]
+ * @param className Class name for ensemble model type
+ * @param treeClassName Class name for tree model type in the ensemble
+ * @return (ensemble metadata, array over trees of (tree metadata, root node)),
+ * where the root node is linked with all descendents
+ * @see [[saveImpl()]] for how the model was saved
+ */
+ def loadImpl(
+ path: String,
+ sql: SQLContext,
+ className: String,
+ treeClassName: String): (Metadata, Array[(Metadata, Node)]) = {
+ import sql.implicits._
+ implicit val format = DefaultFormats
+ val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
+
+ // Get impurity to construct ImpurityCalculator for each node
+ val impurityType: String = {
+ val impurityJson: JValue = metadata.getParamValue("impurity")
+ Param.jsonDecode[String](compact(render(impurityJson)))
+ }
+
+ val treesMetadataPath = new Path(path, "treesMetadata").toString
+ val treesMetadataRDD: RDD[(Int, Metadata)] = sql.read.parquet(treesMetadataPath)
+ .select("treeID", "metadata").as[(Int, String)].rdd.map {
+ case (treeID: Int, json: String) =>
+ treeID -> DefaultParamsReader.parseMetadata(json, treeClassName)
+ }
+ val treesMetadata: Array[Metadata] = treesMetadataRDD.sortByKey().values.collect()
+
+ val dataPath = new Path(path, "data").toString
+ val nodeData: Dataset[EnsembleNodeData] =
+ sql.read.parquet(dataPath).as[EnsembleNodeData]
+ val rootNodesRDD: RDD[(Int, Node)] =
+ nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
+ case (treeID: Int, nodeData: Iterable[NodeData]) =>
+ treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
+ }
+ val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
+ (metadata, treesMetadata.zip(rootNodes))
+ }
+
+ /**
+ * Info for one [[Node]] in a tree ensemble
+ *
+ * @param treeID Tree index
+ * @param nodeData Data for this node
+ */
+ case class EnsembleNodeData(
+ treeID: Int,
+ nodeData: NodeData)
+
+ object EnsembleNodeData {
+ /**
+ * Create [[EnsembleNodeData]] instances for the given tree.
+ *
+ * @return Sequence of nodes for this tree
+ */
+ def build(tree: DecisionTreeModel, treeID: Int): Seq[EnsembleNodeData] = {
+ val (nodeData: Seq[NodeData], _) = NodeData.build(tree.rootNode, 0)
+ nodeData.map(nd => EnsembleNodeData(treeID, nd))
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 4fbd957677..78e6d3bfac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -315,22 +315,8 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
}
}
-/**
- * Parameters for Random Forest algorithms.
- *
- * Note: Marked as private and DeveloperApi since this may be made public in the future.
- */
-private[ml] trait RandomForestParams extends TreeEnsembleParams {
-
- /**
- * Number of trees to train (>= 1).
- * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
- * TODO: Change to always do bootstrapping (simpler). SPARK-7130
- * (default = 20)
- * @group param
- */
- final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
- ParamValidators.gtEq(1))
+/** Used for [[RandomForestParams]] */
+private[ml] trait HasFeatureSubsetStrategy extends Params {
/**
* The number of features to consider for splits at each tree node.
@@ -362,27 +348,65 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
(value: String) =>
RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
- setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
+ setDefault(featureSubsetStrategy -> "auto")
/** @group setParam */
- def setNumTrees(value: Int): this.type = set(numTrees, value)
+ def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
/** @group getParam */
- final def getNumTrees: Int = $(numTrees)
+ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
+}
+
+/**
+ * Used for [[RandomForestParams]].
+ * This is separated out from [[RandomForestParams]] because of an issue with the
+ * `numTrees` method conflicting with this Param in the Estimator.
+ */
+private[ml] trait HasNumTrees extends Params {
+
+ /**
+ * Number of trees to train (>= 1).
+ * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+ * TODO: Change to always do bootstrapping (simpler). SPARK-7130
+ * (default = 20)
+ * @group param
+ */
+ final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+ ParamValidators.gtEq(1))
+
+ setDefault(numTrees -> 20)
/** @group setParam */
- def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
+ def setNumTrees(value: Int): this.type = set(numTrees, value)
/** @group getParam */
- final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
+ final def getNumTrees: Int = $(numTrees)
}
+/**
+ * Parameters for Random Forest algorithms.
+ */
+private[ml] trait RandomForestParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with HasNumTrees
+
private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
}
+private[ml] trait RandomForestClassifierParams
+ extends RandomForestParams with TreeClassifierParams
+
+private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with TreeClassifierParams
+
+private[ml] trait RandomForestRegressorParams
+ extends RandomForestParams with TreeRegressorParams
+
+private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with TreeRegressorParams
+
/**
* Parameters for Gradient-Boosted Tree algorithms.
*
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 39999ede30..7dec07ea14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -144,6 +144,7 @@ private[ml] trait DefaultParamsWritable extends MLWritable { self: Params =>
/**
* Abstract class for utility classes that can load ML instances.
+ *
* @tparam T ML instance type
*/
@Experimental
@@ -162,6 +163,7 @@ abstract class MLReader[T] extends BaseReadWrite {
/**
* Trait for objects that provide [[MLReader]].
+ *
* @tparam T ML instance type
*/
@Experimental
@@ -192,6 +194,7 @@ private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] {
* Default [[MLWriter]] implementation for transformers and estimators that contain basic
* (json4s-serializable) params and no data. This will not handle more complex params or types with
* data (e.g., models with coefficients).
+ *
* @param instance object to save
*/
private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {
@@ -211,6 +214,7 @@ private[ml] object DefaultParamsWriter {
* - uid
* - paramMap
* - (optionally, extra metadata)
+ *
* @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
* @param paramMap If given, this is saved in the "paramMap" field.
* Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
@@ -222,6 +226,22 @@ private[ml] object DefaultParamsWriter {
sc: SparkContext,
extraMetadata: Option[JObject] = None,
paramMap: Option[JValue] = None): Unit = {
+ val metadataPath = new Path(path, "metadata").toString
+ val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
+ sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+ }
+
+ /**
+ * Helper for [[saveMetadata()]] which extracts the JSON to save.
+ * This is useful for ensemble models which need to save metadata for many sub-models.
+ *
+ * @see [[saveMetadata()]] for details on what this includes.
+ */
+ def getMetadataToSave(
+ instance: Params,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None,
+ paramMap: Option[JValue] = None): String = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -239,9 +259,8 @@ private[ml] object DefaultParamsWriter {
case None =>
basicMetadata
}
- val metadataPath = new Path(path, "metadata").toString
- val metadataJson = compact(render(metadata))
- sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+ val metadataJson: String = compact(render(metadata))
+ metadataJson
}
}
@@ -249,6 +268,7 @@ private[ml] object DefaultParamsWriter {
* Default [[MLReader]] implementation for transformers and estimators that contain basic
* (json4s-serializable) params and no data. This will not handle more complex params or types with
* data (e.g., models with coefficients).
+ *
* @tparam T ML instance type
* TODO: Consider adding check for correct class name.
*/
@@ -268,6 +288,7 @@ private[ml] object DefaultParamsReader {
/**
* All info from metadata file.
+ *
* @param params paramMap, as a [[JValue]]
* @param metadata All metadata, including the other fields
* @param metadataJson Full metadata file String (for debugging)
@@ -304,13 +325,26 @@ private[ml] object DefaultParamsReader {
}
/**
- * Load metadata from file.
+ * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]]
+ *
* @param expectedClassName If non empty, this is checked against the loaded metadata.
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
*/
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
val metadataPath = new Path(path, "metadata").toString
val metadataStr = sc.textFile(metadataPath, 1).first()
+ parseMetadata(metadataStr, expectedClassName)
+ }
+
+ /**
+ * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]].
+ * This is a helper function for [[loadMetadata()]].
+ *
+ * @param metadataStr JSON string of metadata
+ * @param expectedClassName If non empty, this is checked against the loaded metadata.
+ * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+ */
+ def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
val metadata = parse(metadataStr)
implicit val format = DefaultFormats
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 052bc83c38..aaaa429103 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@@ -34,7 +34,8 @@ import org.apache.spark.sql.{DataFrame, Row}
/**
* Test suite for [[RandomForestClassifier]].
*/
-class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RandomForestClassifierSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
import RandomForestClassifierSuite.compareAPIs
@@ -190,27 +191,24 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
- test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees =
- Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray
- val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees)
- val newModel = RandomForestClassificationModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = RandomForestClassificationModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ test("read/write") {
+ def checkModelData(
+ model: RandomForestClassificationModel,
+ model2: RandomForestClassificationModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
+ assert(model.numClasses === model2.numClasses)
}
+
+ val rf = new RandomForestClassifier().setNumTrees(2)
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy")
+
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
}
- */
}
private object RandomForestClassifierSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 2ab4f1b146..ca400e1914 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -30,7 +30,8 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[RandomForestRegressor]].
*/
-class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest{
import RandomForestRegressorSuite.compareAPIs
@@ -106,26 +107,23 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
- test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
- val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees)
- val newModel = RandomForestRegressionModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = RandomForestRegressionModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ test("read/write") {
+ def checkModelData(
+ model: RandomForestRegressionModel,
+ model2: RandomForestRegressionModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
}
+
+ val rf = new RandomForestRegressor().setNumTrees(2)
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "variance")
+
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
}
- */
}
private object RandomForestRegressorSuite extends SparkFunSuite {