diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala | 121 |
1 files changed, 94 insertions, 27 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 798947b94a..4c4ff278d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -17,18 +17,22 @@ package org.apache.spark.ml.regression +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -41,7 +45,7 @@ import org.apache.spark.sql.functions._ @Experimental final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] - with RandomForestParams with TreeRegressorParams { + with RandomForestRegressorParams with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("rfr")) @@ -89,7 +93,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) - override protected def train(dataset: DataFrame): RandomForestRegressionModel = { + override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -108,7 +112,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val @Since("1.4.0") @Experimental -object RandomForestRegressor { +object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{ /** Accessor for supported impurity settings: variance */ @Since("1.4.0") final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities @@ -117,12 +121,17 @@ object RandomForestRegressor { @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = RandomForestParams.supportedFeatureSubsetStrategies + + @Since("2.0.0") + override def load(path: String): RandomForestRegressor = super.load(path) + } /** * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. * It supports both continuous and categorical features. + * * @param _trees Decision trees in the ensemble. * @param numFeatures Number of features used by this model */ @@ -133,27 +142,29 @@ final class RandomForestRegressionModel private[ml] ( private val _trees: Array[DecisionTreeRegressionModel], override val numFeatures: Int) extends PredictionModel[Vector, RandomForestRegressionModel] - with TreeEnsembleModel with Serializable { + with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { - require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.") + require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.") /** * Construct a random forest regression model, with all trees weighted equally. + * * @param trees Component trees */ private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) = this(Identifiable.randomUID("rfr"), trees, numFeatures) @Since("1.4.0") - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + override def trees: Array[DecisionTreeRegressionModel] = _trees // Note: We may add support for weights (based on tree performance) later on. - private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0) @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) @@ -165,9 +176,17 @@ final class RandomForestRegressionModel private[ml] ( // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. - _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees + _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees } + /** + * Number of trees in ensemble + * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0 + */ + // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams + @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0") + val numTrees: Int = trees.length + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestRegressionModel = { copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) @@ -175,36 +194,83 @@ final class RandomForestRegressionModel private[ml] ( @Since("1.4.0") override def toString: String = { - s"RandomForestRegressionModel (uid=$uid) with $numTrees trees" + s"RandomForestRegressionModel (uid=$uid) with $getNumTrees trees" } /** * Estimate of the importance of each feature. * - * This generalizes the idea of "Gini" importance to other losses, - * following the explanation of Gini importance from "Random Forests" documentation - * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * Each feature's importance is the average of its importance across all trees in the ensemble + * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + * and follows the implementation from scikit-learn. * - * This feature importance is calculated as follows: - * - Average over trees: - * - importance(feature j) = sum (over nodes which split on feature j) of the gain, - * where gain is scaled by the number of instances passing through node - * - Normalize importances for tree to sum to 1. - * - Normalize feature importance vector to sum to 1. + * @see [[DecisionTreeRegressionModel.featureImportances]] */ @Since("1.5.0") - lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) + lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldRandomForestModel = { new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld)) } + + @Since("2.0.0") + override def write: MLWriter = + new RandomForestRegressionModel.RandomForestRegressionModelWriter(this) } -private[ml] object RandomForestRegressionModel { +@Since("2.0.0") +object RandomForestRegressionModel extends MLReadable[RandomForestRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[RandomForestRegressionModel] = new RandomForestRegressionModelReader + + @Since("2.0.0") + override def load(path: String): RandomForestRegressionModel = super.load(path) + + private[RandomForestRegressionModel] + class RandomForestRegressionModelWriter(instance: RandomForestRegressionModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class RandomForestRegressionModelReader extends MLReader[RandomForestRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[RandomForestRegressionModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): RandomForestRegressionModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldRandomForestModel, parent: RandomForestRegressor, categoricalFeatures: Map[Int, Int], @@ -215,6 +281,7 @@ private[ml] object RandomForestRegressionModel { // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestRegressionModel(parent.uid, newTrees, numFeatures) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfr") + new RandomForestRegressionModel(uid, newTrees, numFeatures) } } |