-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala                  7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala       100
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala                       7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala              96
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala                              131
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala                               66
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala                                42
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala    40
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala         38
9 files changed, 424 insertions, 103 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 48ce051d0a..bfefaf1a1a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -192,7 +192,7 @@ final class GBTClassificationModel private[ml](
extends PredictionModel[Vector, GBTClassificationModel]
with TreeEnsembleModel with Serializable {
- require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
@@ -227,6 +227,9 @@ final class GBTClassificationModel private[ml](
if (prediction > 0.0) 1.0 else 0.0
}
+ /** Number of trees in ensemble */
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): GBTClassificationModel = {
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
@@ -272,6 +275,6 @@ private[ml] object GBTClassificationModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
- new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
+ new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
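
A likely reason the require() check above switches from numTrees to _trees.nonEmpty: numTrees is no longer provided by the TreeEnsembleModel trait but declared as a plain val later in the class body, and Scala runs class-body statements in declaration order, so a require() placed before that declaration would read the field's default value of 0. A minimal, hypothetical illustration of that initialization-order pitfall:

    class Demo(xs: Array[Int]) {
      require(n > 0)        // runs before `n` below is assigned, so n is still 0 and this always throws
      val n: Int = xs.length
    }
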
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 82fa05a604..2ad893f4fa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -17,11 +17,15 @@
package org.apache.spark.ml.classification
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -43,7 +47,7 @@ import org.apache.spark.sql.functions._
final class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
- with RandomForestParams with TreeClassifierParams {
+ with RandomForestClassifierParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfc"))
@@ -120,7 +124,7 @@ final class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0")
@Experimental
-object RandomForestClassifier {
+object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] {
/** Accessor for supported impurity settings: entropy, gini */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
@@ -129,6 +133,9 @@ object RandomForestClassifier {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestClassifier = super.load(path)
}
/**
@@ -136,8 +143,9 @@ object RandomForestClassifier {
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification.
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
+ *
* @param _trees Decision trees in the ensemble.
- * Warning: These have null parents.
+ * Warning: These have null parents.
*/
@Since("1.4.0")
@Experimental
@@ -147,12 +155,14 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
- with TreeEnsembleModel with Serializable {
+ with RandomForestClassificationModelParams with TreeEnsembleModel with MLWritable
+ with Serializable {
- require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
/**
* Construct a random forest classification model, with all trees weighted equally.
+ *
* @param trees Component trees
*/
private[ml] def this(
@@ -165,7 +175,7 @@ final class RandomForestClassificationModel private[ml] (
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
// Note: We may add support for weights (based on tree performance) later on.
- private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
@@ -208,6 +218,15 @@ final class RandomForestClassificationModel private[ml] (
}
}
+ /**
+ * Number of trees in ensemble
+ *
+ * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
+ */
+ // TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
+ @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestClassificationModel = {
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
@@ -216,7 +235,7 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.4.0")
override def toString: String = {
- s"RandomForestClassificationModel (uid=$uid) with $numTrees trees"
+ s"RandomForestClassificationModel (uid=$uid) with $getNumTrees trees"
}
/**
@@ -236,12 +255,69 @@ final class RandomForestClassificationModel private[ml] (
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new RandomForestClassificationModel.RandomForestClassificationModelWriter(this)
}
-private[ml] object RandomForestClassificationModel {
+@Since("2.0.0")
+object RandomForestClassificationModel extends MLReadable[RandomForestClassificationModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[RandomForestClassificationModel] =
+ new RandomForestClassificationModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestClassificationModel = super.load(path)
+
+ private[RandomForestClassificationModel]
+ class RandomForestClassificationModelWriter(instance: RandomForestClassificationModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ // Note: numTrees is not currently used, but could be nice to store for fast querying.
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numClasses" -> instance.numClasses,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class RandomForestClassificationModelReader
+ extends MLReader[RandomForestClassificationModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RandomForestClassificationModel].getName
+ private val treeClassName = classOf[DecisionTreeClassificationModel].getName
+
+ override def load(path: String): RandomForestClassificationModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeClassificationModel] = treesData.map {
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+ require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
categoricalFeatures: Map[Int, Int],
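
With the writer and reader added above, a fitted random forest classifier can be persisted and reloaded through the standard ML persistence API. A minimal sketch, assuming a DataFrame `train` with the usual "label" and "features" columns and a writable path:

    import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}

    val rf = new RandomForestClassifier().setNumTrees(20).setImpurity("gini")
    val model = rf.fit(train)                                    // `train` is an assumed DataFrame

    model.write.overwrite().save("/tmp/rfcModel")                // MLWritable, added in this patch
    val restored = RandomForestClassificationModel.load("/tmp/rfcModel")
    assert(restored.getNumTrees == model.getNumTrees)
    assert(restored.numClasses == model.numClasses)
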
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 8fca35da51..02e124a1c0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -179,7 +179,7 @@ final class GBTRegressionModel private[ml](
extends PredictionModel[Vector, GBTRegressionModel]
with TreeEnsembleModel with Serializable {
- require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
@@ -213,6 +213,9 @@ final class GBTRegressionModel private[ml](
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
}
+ /** Number of trees in ensemble */
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): GBTRegressionModel = {
copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
@@ -258,6 +261,6 @@ private[ml] object GBTRegressionModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
- new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
+ new GBTRegressionModel(uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 5b3f3a1f5d..ba56b5cd3f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -17,12 +17,16 @@
package org.apache.spark.ml.regression
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -41,7 +45,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
- with RandomForestParams with TreeRegressorParams {
+ with RandomForestRegressorParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfr"))
@@ -108,7 +112,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val
@Since("1.4.0")
@Experimental
-object RandomForestRegressor {
+object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor] {
/** Accessor for supported impurity settings: variance */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
@@ -117,12 +121,17 @@ object RandomForestRegressor {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressor = super.load(path)
+
}
/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
* It supports both continuous and categorical features.
+ *
* @param _trees Decision trees in the ensemble.
* @param numFeatures Number of features used by this model
*/
@@ -133,12 +142,13 @@ final class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
- with TreeEnsembleModel with Serializable {
+ with RandomForestRegressionModelParams with TreeEnsembleModel with MLWritable with Serializable {
- require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
/**
* Construct a random forest regression model, with all trees weighted equally.
+ *
* @param trees Component trees
*/
private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
@@ -148,7 +158,7 @@ final class RandomForestRegressionModel private[ml] (
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
// Note: We may add support for weights (based on tree performance) later on.
- private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
@@ -165,9 +175,17 @@ final class RandomForestRegressionModel private[ml] (
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
- _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
+ _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
}
+ /**
+ * Number of trees in ensemble
+ * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
+ */
+ // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams
+ @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestRegressionModel = {
copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
@@ -175,7 +193,7 @@ final class RandomForestRegressionModel private[ml] (
@Since("1.4.0")
override def toString: String = {
- s"RandomForestRegressionModel (uid=$uid) with $numTrees trees"
+ s"RandomForestRegressionModel (uid=$uid) with $getNumTrees trees"
}
/**
@@ -195,12 +213,63 @@ final class RandomForestRegressionModel private[ml] (
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new RandomForestRegressionModel.RandomForestRegressionModelWriter(this)
}
-private[ml] object RandomForestRegressionModel {
+@Since("2.0.0")
+object RandomForestRegressionModel extends MLReadable[RandomForestRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[RandomForestRegressionModel] = new RandomForestRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressionModel = super.load(path)
+
+ private[RandomForestRegressionModel]
+ class RandomForestRegressionModelWriter(instance: RandomForestRegressionModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class RandomForestRegressionModelReader extends MLReader[RandomForestRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RandomForestRegressionModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): RandomForestRegressionModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+ require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestRegressor,
categoricalFeatures: Map[Int, Int],
@@ -211,6 +280,7 @@ private[ml] object RandomForestRegressionModel {
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestRegressionModel(parent.uid, newTrees, numFeatures)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfr")
+ new RandomForestRegressionModel(uid, newTrees, numFeatures)
}
}
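
The regressor gets the same treatment, and because the estimator now mixes in DefaultParamsWritable/DefaultParamsReadable, the unfitted estimator itself can be saved and reloaded as well. A minimal sketch with an assumed path:

    import org.apache.spark.ml.regression.RandomForestRegressor

    val rf = new RandomForestRegressor().setNumTrees(30).setFeatureSubsetStrategy("sqrt")
    rf.write.overwrite().save("/tmp/rfrEstimator")               // persists only Params, no data
    val sameRf = RandomForestRegressor.load("/tmp/rfrEstimator")
    assert(sameRf.getNumTrees == 30)
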
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 8ea767b2b3..48b8fd19ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -21,12 +21,15 @@ import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.ml.param.Param
-import org.apache.spark.ml.util.DefaultParamsReader
+import org.apache.spark.ml.param.{Param, Params}
+import org.apache.spark.ml.tree.DecisionTreeModelReadWrite.NodeData
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Dataset, SQLContext}
import org.apache.spark.util.collection.OpenHashMap
/**
@@ -88,6 +91,11 @@ private[ml] trait TreeEnsembleModel {
/** Trees in this ensemble. Warning: These have null parent Estimators. */
def trees: Array[DecisionTreeModel]
+ /**
+ * Number of trees in ensemble
+ */
+ val getNumTrees: Int = trees.length
+
/** Weights for each tree, zippable with [[trees]] */
def treeWeights: Array[Double]
@@ -98,7 +106,7 @@ private[ml] trait TreeEnsembleModel {
/** Summary of the model */
override def toString: String = {
// Implementing classes should generally override this method to be more descriptive.
- s"TreeEnsembleModel with $numTrees trees"
+ s"TreeEnsembleModel with ${trees.length} trees"
}
/** Full description of model */
@@ -109,9 +117,6 @@ private[ml] trait TreeEnsembleModel {
}.fold("")(_ + _)
}
- /** Number of trees in ensemble */
- val numTrees: Int = trees.length
-
/** Total number of nodes, summed over all trees in the ensemble. */
lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
}
@@ -316,6 +321,10 @@ private[ml] object DecisionTreeModelReadWrite {
}
}
+ /**
+ * Load a decision tree from a file.
+ * @return Root node of reconstructed tree
+ */
def loadTreeNodes(
path: String,
metadata: DefaultParamsReader.Metadata,
@@ -331,9 +340,18 @@ private[ml] object DecisionTreeModelReadWrite {
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).as[NodeData]
+ buildTreeFromNodes(data.collect(), impurityType)
+ }
+ /**
+ * Given all data for all nodes in a tree, rebuild the tree.
+ * @param data Unsorted node data
+ * @param impurityType Impurity type for this tree
+ * @return Root node of reconstructed tree
+ */
+ def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
// Load all nodes, sorted by ID.
- val nodes: Array[NodeData] = data.collect().sortBy(_.id)
+ val nodes = data.sortBy(_.id)
// Sanity checks; could remove
assert(nodes.head.id == 0, s"Decision Tree load failed. Expected smallest node ID to be 0," +
s" but found ${nodes.head.id}")
@@ -358,3 +376,100 @@ private[ml] object DecisionTreeModelReadWrite {
finalNodes.head
}
}
+
+private[ml] object EnsembleModelReadWrite {
+
+ /**
+ * Helper method for saving a tree ensemble to disk.
+ *
+ * @param instance Tree ensemble model
+ * @param path Path to which to save the ensemble model.
+ * @param extraMetadata Metadata such as numFeatures, numClasses, numTrees.
+ */
+ def saveImpl[M <: Params with TreeEnsembleModel](
+ instance: M,
+ path: String,
+ sql: SQLContext,
+ extraMetadata: JObject): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata))
+ val treesMetadataJson: Array[(Int, String)] = instance.trees.zipWithIndex.map {
+ case (tree, treeID) =>
+ treeID -> DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext)
+ }
+ val treesMetadataPath = new Path(path, "treesMetadata").toString
+ sql.createDataFrame(treesMetadataJson).toDF("treeID", "metadata")
+ .write.parquet(treesMetadataPath)
+ val dataPath = new Path(path, "data").toString
+ val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap {
+ case (tree, treeID) => EnsembleNodeData.build(tree, treeID)
+ }
+ sql.createDataFrame(nodeDataRDD).write.parquet(dataPath)
+ }
+
+ /**
+ * Helper method for loading a tree ensemble from disk.
+ * This reconstructs all trees, returning the root nodes.
+ * @param path Path given to [[saveImpl()]]
+ * @param className Class name for ensemble model type
+ * @param treeClassName Class name for tree model type in the ensemble
+ * @return (ensemble metadata, array over trees of (tree metadata, root node)),
+ * where the root node is linked with all descendants
+ * @see [[saveImpl()]] for how the model was saved
+ */
+ def loadImpl(
+ path: String,
+ sql: SQLContext,
+ className: String,
+ treeClassName: String): (Metadata, Array[(Metadata, Node)]) = {
+ import sql.implicits._
+ implicit val format = DefaultFormats
+ val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
+
+ // Get impurity to construct ImpurityCalculator for each node
+ val impurityType: String = {
+ val impurityJson: JValue = metadata.getParamValue("impurity")
+ Param.jsonDecode[String](compact(render(impurityJson)))
+ }
+
+ val treesMetadataPath = new Path(path, "treesMetadata").toString
+ val treesMetadataRDD: RDD[(Int, Metadata)] = sql.read.parquet(treesMetadataPath)
+ .select("treeID", "metadata").as[(Int, String)].rdd.map {
+ case (treeID: Int, json: String) =>
+ treeID -> DefaultParamsReader.parseMetadata(json, treeClassName)
+ }
+ val treesMetadata: Array[Metadata] = treesMetadataRDD.sortByKey().values.collect()
+
+ val dataPath = new Path(path, "data").toString
+ val nodeData: Dataset[EnsembleNodeData] =
+ sql.read.parquet(dataPath).as[EnsembleNodeData]
+ val rootNodesRDD: RDD[(Int, Node)] =
+ nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
+ case (treeID: Int, nodeData: Iterable[NodeData]) =>
+ treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
+ }
+ val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
+ (metadata, treesMetadata.zip(rootNodes))
+ }
+
+ /**
+ * Info for one [[Node]] in a tree ensemble
+ *
+ * @param treeID Tree index
+ * @param nodeData Data for this node
+ */
+ case class EnsembleNodeData(
+ treeID: Int,
+ nodeData: NodeData)
+
+ object EnsembleNodeData {
+ /**
+ * Create [[EnsembleNodeData]] instances for the given tree.
+ *
+ * @return Sequence of nodes for this tree
+ */
+ def build(tree: DecisionTreeModel, treeID: Int): Seq[EnsembleNodeData] = {
+ val (nodeData: Seq[NodeData], _) = NodeData.build(tree.rootNode, 0)
+ nodeData.map(nd => EnsembleNodeData(treeID, nd))
+ }
+ }
+}
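
For orientation, this is the on-disk layout that EnsembleModelReadWrite.saveImpl produces and loadImpl expects, sketched from the code above (both helpers are private[ml], so callers must live under org.apache.spark.ml):

    // path/
    //   metadata/       one-line JSON: ensemble-level Params plus extraMetadata (numFeatures, numTrees, ...)
    //   treesMetadata/  Parquet rows (treeID: Int, metadata: String), one JSON string per tree
    //   data/           Parquet rows EnsembleNodeData(treeID, nodeData), one per node across all trees
    EnsembleModelReadWrite.saveImpl(model, path, sqlContext, extraMetadata)
    val (meta, trees) = EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
    // trees: Array[(Metadata, Node)] -- per-tree Params metadata zipped with the rebuilt root nodes
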
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 4fbd957677..78e6d3bfac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -315,22 +315,8 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
}
}
-/**
- * Parameters for Random Forest algorithms.
- *
- * Note: Marked as private and DeveloperApi since this may be made public in the future.
- */
-private[ml] trait RandomForestParams extends TreeEnsembleParams {
-
- /**
- * Number of trees to train (>= 1).
- * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
- * TODO: Change to always do bootstrapping (simpler). SPARK-7130
- * (default = 20)
- * @group param
- */
- final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
- ParamValidators.gtEq(1))
+/** Used for [[RandomForestParams]] */
+private[ml] trait HasFeatureSubsetStrategy extends Params {
/**
* The number of features to consider for splits at each tree node.
@@ -362,27 +348,65 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
(value: String) =>
RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
- setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
+ setDefault(featureSubsetStrategy -> "auto")
/** @group setParam */
- def setNumTrees(value: Int): this.type = set(numTrees, value)
+ def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
/** @group getParam */
- final def getNumTrees: Int = $(numTrees)
+ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
+}
+
+/**
+ * Used for [[RandomForestParams]].
+ * This is separated out from [[RandomForestParams]] because of an issue with the
+ * `numTrees` method conflicting with this Param in the Estimator.
+ */
+private[ml] trait HasNumTrees extends Params {
+
+ /**
+ * Number of trees to train (>= 1).
+ * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+ * TODO: Change to always do bootstrapping (simpler). SPARK-7130
+ * (default = 20)
+ * @group param
+ */
+ final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+ ParamValidators.gtEq(1))
+
+ setDefault(numTrees -> 20)
/** @group setParam */
- def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
+ def setNumTrees(value: Int): this.type = set(numTrees, value)
/** @group getParam */
- final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
+ final def getNumTrees: Int = $(numTrees)
}
+/**
+ * Parameters for Random Forest algorithms.
+ */
+private[ml] trait RandomForestParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with HasNumTrees
+
private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
}
+private[ml] trait RandomForestClassifierParams
+ extends RandomForestParams with TreeClassifierParams
+
+private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with TreeClassifierParams
+
+private[ml] trait RandomForestRegressorParams
+ extends RandomForestParams with TreeRegressorParams
+
+private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with TreeRegressorParams
+
/**
* Parameters for Gradient-Boosted Tree algorithms.
*
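
The HasNumTrees split above exists so that the estimators keep the numTrees Param while the model classes, which still expose a deprecated numTrees value, avoid the name clash by mixing in the *ModelParams traits instead. What callers see after this change, as a brief sketch (`train` is an assumed DataFrame):

    val rf = new RandomForestClassifier().setNumTrees(50)   // Param, via HasNumTrees on the estimator
    val model = rf.fit(train)

    model.getNumTrees    // preferred accessor, defined on TreeEnsembleModel
    model.numTrees       // still compiles, but deprecated since 2.0.0
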
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 39999ede30..7dec07ea14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -144,6 +144,7 @@ private[ml] trait DefaultParamsWritable extends MLWritable { self: Params =>
/**
* Abstract class for utility classes that can load ML instances.
+ *
* @tparam T ML instance type
*/
@Experimental
@@ -162,6 +163,7 @@ abstract class MLReader[T] extends BaseReadWrite {
/**
* Trait for objects that provide [[MLReader]].
+ *
* @tparam T ML instance type
*/
@Experimental
@@ -192,6 +194,7 @@ private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] {
* Default [[MLWriter]] implementation for transformers and estimators that contain basic
* (json4s-serializable) params and no data. This will not handle more complex params or types with
* data (e.g., models with coefficients).
+ *
* @param instance object to save
*/
private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {
@@ -211,6 +214,7 @@ private[ml] object DefaultParamsWriter {
* - uid
* - paramMap
* - (optionally, extra metadata)
+ *
* @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
* @param paramMap If given, this is saved in the "paramMap" field.
* Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
@@ -222,6 +226,22 @@ private[ml] object DefaultParamsWriter {
sc: SparkContext,
extraMetadata: Option[JObject] = None,
paramMap: Option[JValue] = None): Unit = {
+ val metadataPath = new Path(path, "metadata").toString
+ val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
+ sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+ }
+
+ /**
+ * Helper for [[saveMetadata()]] which extracts the JSON to save.
+ * This is useful for ensemble models which need to save metadata for many sub-models.
+ *
+ * @see [[saveMetadata()]] for details on what this includes.
+ */
+ def getMetadataToSave(
+ instance: Params,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None,
+ paramMap: Option[JValue] = None): String = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -239,9 +259,8 @@ private[ml] object DefaultParamsWriter {
case None =>
basicMetadata
}
- val metadataPath = new Path(path, "metadata").toString
- val metadataJson = compact(render(metadata))
- sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+ val metadataJson: String = compact(render(metadata))
+ metadataJson
}
}
@@ -249,6 +268,7 @@ private[ml] object DefaultParamsWriter {
* Default [[MLReader]] implementation for transformers and estimators that contain basic
* (json4s-serializable) params and no data. This will not handle more complex params or types with
* data (e.g., models with coefficients).
+ *
* @tparam T ML instance type
* TODO: Consider adding check for correct class name.
*/
@@ -268,6 +288,7 @@ private[ml] object DefaultParamsReader {
/**
* All info from metadata file.
+ *
* @param params paramMap, as a [[JValue]]
* @param metadata All metadata, including the other fields
* @param metadataJson Full metadata file String (for debugging)
@@ -304,13 +325,26 @@ private[ml] object DefaultParamsReader {
}
/**
- * Load metadata from file.
+ * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]]
+ *
* @param expectedClassName If non empty, this is checked against the loaded metadata.
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
*/
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
val metadataPath = new Path(path, "metadata").toString
val metadataStr = sc.textFile(metadataPath, 1).first()
+ parseMetadata(metadataStr, expectedClassName)
+ }
+
+ /**
+ * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]].
+ * This is a helper function for [[loadMetadata()]].
+ *
+ * @param metadataStr JSON string of metadata
+ * @param expectedClassName If non empty, this is checked against the loaded metadata.
+ * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+ */
+ def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
val metadata = parse(metadataStr)
implicit val format = DefaultFormats
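
Splitting getMetadataToSave/parseMetadata out of saveMetadata/loadMetadata is what lets the ensemble writer store one metadata JSON string per tree in a single Parquet column instead of one directory per tree. A minimal sketch of the in-memory round trip, assuming code under org.apache.spark.ml (both helpers are private[ml]) and an existing tree model `tree` plus a freshly constructed `freshTree`:

    import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter}

    val json: String = DefaultParamsWriter.getMetadataToSave(tree, sc)   // metadata JSON, no files written
    // ... persist `json` as one row of a Parquet column, as EnsembleModelReadWrite does ...
    val meta = DefaultParamsReader.parseMetadata(json)                   // no files read
    DefaultParamsReader.getAndSetParams(freshTree, meta)                 // copy the Params onto the new instance
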
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 052bc83c38..aaaa429103 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@@ -34,7 +34,8 @@ import org.apache.spark.sql.{DataFrame, Row}
/**
* Test suite for [[RandomForestClassifier]].
*/
-class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RandomForestClassifierSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
import RandomForestClassifierSuite.compareAPIs
@@ -190,27 +191,24 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
- test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees =
- Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray
- val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees)
- val newModel = RandomForestClassificationModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = RandomForestClassificationModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ test("read/write") {
+ def checkModelData(
+ model: RandomForestClassificationModel,
+ model2: RandomForestClassificationModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
+ assert(model.numClasses === model2.numClasses)
}
+
+ val rf = new RandomForestClassifier().setNumTrees(2)
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy")
+
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
}
- */
}
private object RandomForestClassifierSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 2ab4f1b146..ca400e1914 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -30,7 +30,8 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[RandomForestRegressor]].
*/
-class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
import RandomForestRegressorSuite.compareAPIs
@@ -106,26 +107,23 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
- test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
- val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees)
- val newModel = RandomForestRegressionModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = RandomForestRegressionModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ test("read/write") {
+ def checkModelData(
+ model: RandomForestRegressionModel,
+ model2: RandomForestRegressionModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
}
+
+ val rf = new RandomForestRegressor().setNumTrees(2)
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "variance")
+
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
}
- */
}
private object RandomForestRegressorSuite extends SparkFunSuite {