aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala121
1 files changed, 94 insertions, 27 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 798947b94a..4c4ff278d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -17,18 +17,22 @@
package org.apache.spark.ml.regression
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -41,7 +45,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
- with RandomForestParams with TreeRegressorParams {
+ with RandomForestRegressorParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfr"))
@@ -89,7 +93,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)
- override protected def train(dataset: DataFrame): RandomForestRegressionModel = {
+ override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -108,7 +112,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val
@Since("1.4.0")
@Experimental
-object RandomForestRegressor {
+object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{
/** Accessor for supported impurity settings: variance */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
@@ -117,12 +121,17 @@ object RandomForestRegressor {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressor = super.load(path)
+
}
/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
* It supports both continuous and categorical features.
+ *
* @param _trees Decision trees in the ensemble.
* @param numFeatures Number of features used by this model
*/
@@ -133,27 +142,29 @@ final class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
- with TreeEnsembleModel with Serializable {
+ with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
- require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
/**
* Construct a random forest regression model, with all trees weighted equally.
+ *
* @param trees Component trees
*/
private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
this(Identifiable.randomUID("rfr"), trees, numFeatures)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
// Note: We may add support for weights (based on tree performance) later on.
- private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
@@ -165,9 +176,17 @@ final class RandomForestRegressionModel private[ml] (
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
- _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
+ _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
}
+ /**
+ * Number of trees in ensemble
+ * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
+ */
+ // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams
+ @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestRegressionModel = {
copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
@@ -175,36 +194,83 @@ final class RandomForestRegressionModel private[ml] (
@Since("1.4.0")
override def toString: String = {
- s"RandomForestRegressionModel (uid=$uid) with $numTrees trees"
+ s"RandomForestRegressionModel (uid=$uid) with $getNumTrees trees"
}
/**
* Estimate of the importance of each feature.
*
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2009.)
+ * and follows the implementation from scikit-learn.
*
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
+ * @see [[DecisionTreeRegressionModel.featureImportances]]
*/
@Since("1.5.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new RandomForestRegressionModel.RandomForestRegressionModelWriter(this)
}
-private[ml] object RandomForestRegressionModel {
+@Since("2.0.0")
+object RandomForestRegressionModel extends MLReadable[RandomForestRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[RandomForestRegressionModel] = new RandomForestRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressionModel = super.load(path)
+
+ private[RandomForestRegressionModel]
+ class RandomForestRegressionModelWriter(instance: RandomForestRegressionModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class RandomForestRegressionModelReader extends MLReader[RandomForestRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RandomForestRegressionModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): RandomForestRegressionModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+ require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestRegressor,
categoricalFeatures: Map[Int, Int],
@@ -215,6 +281,7 @@ private[ml] object RandomForestRegressionModel {
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestRegressionModel(parent.uid, newTrees, numFeatures)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfr")
+ new RandomForestRegressionModel(uid, newTrees, numFeatures)
}
}