diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression')
7 files changed, 386 insertions, 170 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index ba5708ab8d..89ba6ab5d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -31,8 +31,9 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -103,7 +104,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) if (fitting) { SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) } if (hasQuantilesCol) { SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT) @@ -183,24 +184,35 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, * and put it in an RDD with strong types. */ - protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { - dataset.select($(featuresCol), $(labelCol), $(censorCol)).rdd.map { - case Row(features: Vector, label: Double, censor: Double) => - AFTPoint(features, label, censor) - } + protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = { + dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol))) + .rdd.map { + case Row(features: Vector, label: Double, censor: Double) => + AFTPoint(features, label, censor) + } } - @Since("1.6.0") - override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = { validateAndTransformSchema(dataset.schema, fitting = true) val instances = extractAFTPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val costFun = new AFTCostFun(instances, $(fitIntercept)) + val featuresSummarizer = { + val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features) + val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => { + c1.merge(c2) + } + instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp) + } + + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + + val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd) val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) - val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size + val numFeatures = featuresStd.size /* The parameters vector has three parts: the first element: Double, log(sigma), the log of scale parameter @@ -229,7 +241,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S if (handlePersistence) instances.unpersist() - val coefficients = Vectors.dense(parameters.slice(2, parameters.length)) + val rawCoefficients = parameters.slice(2, parameters.length) + var i = 0 + while (i < numFeatures) { + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } + i += 1 + } + val coefficients = Vectors.dense(rawCoefficients) val intercept = parameters(1) val scale = math.exp(parameters(0)) val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) @@ -298,8 +316,8 @@ class AFTSurvivalRegressionModel private[ml] ( math.exp(BLAS.dot(coefficients, features) + intercept) } - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) val predictUDF = udf { features: Vector => predict(features) } val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} @@ -433,29 +451,36 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] * @param parameters including three part: The log of scale parameter, the intercept and * regression coefficients corresponding to the features. * @param fitIntercept Whether to fit an intercept term. + * @param featuresStd The standard deviation values of the features. */ -private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) - extends Serializable { +private class AFTAggregator( + parameters: BDV[Double], + fitIntercept: Boolean, + featuresStd: Array[Double]) extends Serializable { // the regression coefficients to the covariates private val coefficients = parameters.slice(2, parameters.length) - private val intercept = parameters.valueAt(1) + private val intercept = parameters(1) // sigma is the scale parameter of the AFT model private val sigma = math.exp(parameters(0)) private var totalCnt: Long = 0L private var lossSum = 0.0 - private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length) - private var gradientInterceptSum = 0.0 - private var gradientLogSigmaSum = 0.0 + // Here we optimize loss function over log(sigma), intercept and coefficients + private val gradientSumArray = Array.ofDim[Double](parameters.length) def count: Long = totalCnt + def loss: Double = { + require(totalCnt > 0.0, s"The number of instances should be " + + s"greater than 0.0, but got $totalCnt.") + lossSum / totalCnt + } + def gradient: BDV[Double] = { + require(totalCnt > 0.0, s"The number of instances should be " + + s"greater than 0.0, but got $totalCnt.") + new BDV(gradientSumArray.map(_ / totalCnt.toDouble)) + } - def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt - - // Here we optimize loss function over coefficients, intercept and log(sigma) - def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)), - BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble) /** * Add a new training data to this AFTAggregator, and update the loss and gradient @@ -465,25 +490,32 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) * @return This AFTAggregator object. */ def add(data: AFTPoint): this.type = { - - val interceptFlag = if (fitIntercept) 1.0 else 0.0 - - val xi = data.features.toBreeze + val xi = data.features val ti = data.label val delta = data.censor - val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma - lossSum += math.log(sigma) * delta - lossSum += (math.exp(epsilon) - delta * epsilon) + val margin = { + var sum = 0.0 + xi.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + sum += coefficients(index) * (value / featuresStd(index)) + } + } + sum + intercept + } + val epsilon = (math.log(ti) - margin) / sigma + + lossSum += delta * math.log(sigma) - delta * epsilon + math.exp(epsilon) - // Sanity check (should never occur): - assert(!lossSum.isInfinity, - s"AFTAggregator loss sum is infinity. Error for unknown reason.") + val multiplier = (delta - math.exp(epsilon)) / sigma - val deltaMinusExpEps = delta - math.exp(epsilon) - gradientCoefficientSum += xi * deltaMinusExpEps / sigma - gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma - gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon + gradientSumArray(0) += delta + multiplier * sigma * epsilon + gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 } + xi.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + gradientSumArray(index + 2) += multiplier * (value / featuresStd(index)) + } + } totalCnt += 1 this @@ -502,9 +534,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) totalCnt += other.totalCnt lossSum += other.lossSum - gradientCoefficientSum += other.gradientCoefficientSum - gradientInterceptSum += other.gradientInterceptSum - gradientLogSigmaSum += other.gradientLogSigmaSum + var i = 0 + val len = this.gradientSumArray.length + while (i < len) { + this.gradientSumArray(i) += other.gradientSumArray(i) + i += 1 + } } this } @@ -515,12 +550,15 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) * It returns the loss and gradient at a particular point (parameters). * It's used in Breeze's convex optimization routines. */ -private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean) - extends DiffFunction[BDV[Double]] { +private class AFTCostFun( + data: RDD[AFTPoint], + fitIntercept: Boolean, + featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] { override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = { - val aftAggregator = data.treeAggregate(new AFTAggregator(parameters, fitIntercept))( + val aftAggregator = data.treeAggregate( + new AFTAggregator(parameters, fitIntercept, featuresStd))( seqOp = (c, v) => (c, v) match { case (aggregator, instance) => aggregator.add(instance) }, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 50ac96eb5e..c04c416aaf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -83,7 +83,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val /** @group setParam */ def setVarianceCol(value: String): this.type = set(varianceCol, value) - override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { + override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -158,15 +158,16 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predictImpl(features).impurityStats.calculate() } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) transformImpl(dataset) } - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val predictUDF = udf { (features: Vector) => predict(features) } val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } - var output = dataset + var output = dataset.toDF if ($(predictionCol).nonEmpty) { output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -203,9 +204,9 @@ final class DecisionTreeRegressionModel private[ml] ( * to determine feature importance instead. */ @Since("2.0.0") - lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures) + lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) - /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */ + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ override private[spark] def toOld: OldDecisionTreeModel = { new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index da5b77e8fa..741724d7a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -18,23 +18,23 @@ package org.apache.spark.ml.regression import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, - TreeRegressorParams} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, - SquaredError => OldSquaredError} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ /** @@ -42,12 +42,24 @@ import org.apache.spark.sql.functions._ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * learning algorithm for regression. * It supports both continuous and categorical features. + * + * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * + * Notes on Gradient Boosting vs. TreeBoost: + * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. + * - Both algorithms learn tree ensembles by minimizing loss functions. + * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes + * based on the loss function, whereas the original gradient boosting method does not. + * - When the loss is SquaredError, these methods give the same result, but they could differ + * for other loss functions. + * - We expect to implement TreeBoost in the future: + * [https://issues.apache.org/jira/browse/SPARK-4240] */ @Since("1.4.0") @Experimental final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] - with GBTParams with TreeRegressorParams with Logging { + with GBTRegressorParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtr")) @@ -101,42 +113,13 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) - // Parameters for GBTRegressor: - - /** - * Loss function which GBT tries to minimize. (case-insensitive) - * Supported: "squared" (L2) and "absolute" (L1) - * (default = squared) - * @group param - */ - @Since("1.4.0") - val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + - " tries to minimize (case-insensitive). Supported options:" + - s" ${GBTRegressor.supportedLossTypes.mkString(", ")}", - (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase)) - - setDefault(lossType -> "squared") + // Parameters from GBTRegressorParams: /** @group setParam */ @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) - /** @group getParam */ - @Since("1.4.0") - def getLossType: String = $(lossType).toLowerCase - - /** (private[ml]) Convert new loss to old loss. */ - override private[ml] def getOldLossType: OldLoss = { - getLossType match { - case "squared" => OldSquaredError - case "absolute" => OldAbsoluteError - case _ => - // Should never happen because of check in setter method. - throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") - } - } - - override protected def train(dataset: DataFrame): GBTRegressionModel = { + override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -153,11 +136,14 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri @Since("1.4.0") @Experimental -object GBTRegressor { - // The losses below should be lowercase. +object GBTRegressor extends DefaultParamsReadable[GBTRegressor] { + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ @Since("1.4.0") - final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes + + @Since("2.0.0") + override def load(path: String): GBTRegressor = super.load(path) } /** @@ -177,9 +163,10 @@ final class GBTRegressionModel private[ml]( private val _treeWeights: Array[Double], override val numFeatures: Int) extends PredictionModel[Vector, GBTRegressionModel] - with TreeEnsembleModel with Serializable { + with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { - require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.") + require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") @@ -193,12 +180,12 @@ final class GBTRegressionModel private[ml]( this(uid, _trees, _treeWeights, -1) @Since("1.4.0") - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + override def trees: Array[DecisionTreeRegressionModel] = _trees @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) @@ -213,6 +200,9 @@ final class GBTRegressionModel private[ml]( blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } + /** Number of trees in ensemble */ + val numTrees: Int = trees.length + @Since("1.4.0") override def copy(extra: ParamMap): GBTRegressionModel = { copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures), @@ -224,16 +214,81 @@ final class GBTRegressionModel private[ml]( s"GBTRegressionModel (uid=$uid) with $numTrees trees" } + /** + * Estimate of the importance of each feature. + * + * Each feature's importance is the average of its importance across all trees in the ensemble + * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + * and follows the implementation from scikit-learn. + * + * @see [[DecisionTreeRegressionModel.featureImportances]] + */ + @Since("2.0.0") + lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) + /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) } + + @Since("2.0.0") + override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this) } -private[ml] object GBTRegressionModel { +@Since("2.0.0") +object GBTRegressionModel extends MLReadable[GBTRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader + + @Since("2.0.0") + override def load(path: String): GBTRegressionModel = super.load(path) + + private[GBTRegressionModel] + class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GBTRegressionModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): GBTRegressionModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + + require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldGBTModel, parent: GBTRegressor, categoricalFeatures: Map[Int, Int], @@ -245,6 +300,6 @@ private[ml] object GBTRegressionModel { DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") - new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures) + new GBTRegressionModel(uid, newTrees, oldModel.treeWeights, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 0e71e8d8e1..e92a3e7fa1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -31,9 +31,9 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** * Params for Generalized Linear Regression. @@ -47,6 +47,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam * to be used in the model. * Supported options: "gaussian", "binomial", "poisson" and "gamma". * Default is "gaussian". + * * @group param */ @Since("2.0.0") @@ -63,6 +64,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam * Param for the name of link function which provides the relationship * between the linear predictor and the mean of the distribution function. * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt". + * * @group param */ @Since("2.0.0") @@ -163,7 +165,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val setDefault(tol -> 1E-6) /** - * Sets the regularization parameter. + * Sets the regularization parameter for L2 regularization. + * The regularization term is + * {{{ + * 0.5 * regParam * L2norm(coefficients)^2 + * }}} * Default is 0.0. * @group setParam */ @@ -190,7 +196,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val def setSolver(value: String): this.type = set(solver, value) setDefault(solver -> "irls") - override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = { + override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = { val familyObj = Family.fromName($(family)) val linkObj = if (isDefined(link)) { Link.fromName($(link)) @@ -210,9 +216,10 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val } val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd - .map { case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) + val instances: RDD[Instance] = + dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) } if (familyObj == Gaussian && linkObj == Identity) { @@ -230,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val predictionColName, model, wlsModel.diagInvAtWA.toArray, - 1) + 1, + getSolver) return model.setSummary(trainingSummary) } @@ -250,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val predictionColName, model, irlsModel.diagInvAtWA.toArray, - irlsModel.numIterations) + irlsModel.numIterations, + getSolver) model.setSummary(trainingSummary) } @@ -698,7 +707,7 @@ class GeneralizedLinearRegressionModel private[ml] ( : (GeneralizedLinearRegressionModel, String) = { $(predictionCol) match { case "" => - val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() + val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName) case p => (this, p) } @@ -769,11 +778,12 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr * :: Experimental :: * Summarizing Generalized Linear regression Fits. * - * @param predictions predictions outputted by the model's `transform` method + * @param predictions predictions output by the model's `transform` method * @param predictionCol field in "predictions" which gives the prediction value of each instance * @param model the model that should be summarized * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration * @param numIterations number of iterations + * @param solver the solver algorithm used for model training */ @Since("2.0.0") @Experimental @@ -782,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( @Since("2.0.0") val predictionCol: String, @Since("2.0.0") val model: GeneralizedLinearRegressionModel, private val diagInvAtWA: Array[Double], - @Since("2.0.0") val numIterations: Int) extends Serializable { + @Since("2.0.0") val numIterations: Int, + @Since("2.0.0") val solver: String) extends Serializable { import GeneralizedLinearRegression._ @@ -930,6 +941,9 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** * Standard error of estimated coefficients and intercept. + * + * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. */ @Since("2.0.0") lazy val coefficientStandardErrors: Array[Double] = { @@ -938,6 +952,9 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** * T-statistic of estimated coefficients and intercept. + * + * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. */ @Since("2.0.0") lazy val tValues: Array[Double] = { @@ -951,6 +968,9 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** * Two-sided p-value of estimated coefficients and intercept. + * + * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. */ @Since("2.0.0") lazy val pValues: Array[Double] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index fb733f9a34..7a78ecbdf1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -77,7 +77,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures * Extracts (label, feature, weight) from input dataset. */ protected[ml] def extractWeightedLabeledPoints( - dataset: DataFrame): RDD[(Double, Double, Double)] = { + dataset: Dataset[_]): RDD[(Double, Double, Double)] = { val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) { val idx = $(featureIndex) val extract = udf { v: Vector => v(idx) } @@ -90,7 +90,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } else { lit(1.0) } - dataset.select(col($(labelCol)), f, w).rdd.map { + dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map { case Row(label: Double, feature: Double, weight: Double) => (label, feature, weight) } @@ -106,7 +106,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures schema: StructType, fitting: Boolean): StructType = { if (fitting) { - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) if (hasWeightCol) { SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType) } else { @@ -164,8 +164,8 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri @Since("1.5.0") override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) - @Since("1.5.0") - override def fit(dataset: DataFrame): IsotonicRegressionModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): IsotonicRegressionModel = { validateAndTransformSchema(dataset.schema, fitting = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractWeightedLabeledPoints(dataset) @@ -236,8 +236,8 @@ class IsotonicRegressionModel private[ml] ( copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent) } - @Since("1.5.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val predict = dataset.schema($(featuresCol)).dataType match { case DoubleType => udf { feature: Double => oldModel.predict(feature) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index b81c588e44..71e02730c7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -38,8 +38,9 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel /** @@ -57,7 +58,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams * The specific squared error loss function used is: * L = 1/2n ||A coefficients - y||^2^ * - * This support multiple types of regularization: + * This supports multiple types of regularization: * - none (a.k.a. ordinary least squares) * - L2 (ridge regression) * - L1 (Lasso) @@ -157,7 +158,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String def setSolver(value: String): this.type = set(solver, value) setDefault(solver -> "auto") - override protected def train(dataset: DataFrame): LinearRegressionModel = { + override protected def train(dataset: Dataset[_]): LinearRegressionModel = { // Extract the number of features before deciding optimization solver. val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map { case Row(features: Vector) => features.size @@ -171,7 +172,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // For low dimensional data, WeightedLeastSquares is more efficiently since the // training algorithm only requires one pass through the data. (SPARK-10668) val instances: RDD[Instance] = dataset.select( - col($(labelCol)), w, col($(featuresCol))).rdd.map { + col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } @@ -189,9 +190,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String summaryModel.transform(dataset), predictionColName, $(labelCol), + $(featuresCol), summaryModel, model.diagInvAtWA.toArray, - $(featuresCol), Array(0D)) return lrModel.setSummary(trainingSummary) @@ -248,9 +249,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String summaryModel.transform(dataset), predictionColName, $(labelCol), + $(featuresCol), model, Array(0D), - $(featuresCol), Array(0D)) return copyValues(model.setSummary(trainingSummary)) } else { @@ -355,9 +356,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String summaryModel.transform(dataset), predictionColName, $(labelCol), + $(featuresCol), model, Array(0D), - $(featuresCol), objectiveHistory) model.setSummary(trainingSummary) } @@ -412,15 +413,15 @@ class LinearRegressionModel private[ml] ( def hasSummary: Boolean = trainingSummary.isDefined /** - * Evaluates the model on a testset. + * Evaluates the model on a test dataset. * @param dataset Test dataset to evaluate model on. */ - // TODO: decide on a good name before exposing to public API - private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { + @Since("2.0.0") + def evaluate(dataset: Dataset[_]): LinearRegressionSummary = { // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, - $(labelCol), this, Array(0D)) + $(labelCol), $(featuresCol), summaryModel, Array(0D)) } /** @@ -431,7 +432,7 @@ class LinearRegressionModel private[ml] ( private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = { $(predictionCol) match { case "" => - val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() + val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName) case p => (this, p) } @@ -510,9 +511,9 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { /** * :: Experimental :: * Linear regression training results. Currently, the training summary ignores the - * training coefficients except for the objective trace. + * training weights except for the objective trace. * - * @param predictions predictions outputted by the model's `transform` method. + * @param predictions predictions output by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Since("1.5.0") @@ -521,13 +522,24 @@ class LinearRegressionTrainingSummary private[regression] ( predictions: DataFrame, predictionCol: String, labelCol: String, + featuresCol: String, model: LinearRegressionModel, diagInvAtWA: Array[Double], - val featuresCol: String, val objectiveHistory: Array[Double]) - extends LinearRegressionSummary(predictions, predictionCol, labelCol, model, diagInvAtWA) { + extends LinearRegressionSummary( + predictions, + predictionCol, + labelCol, + featuresCol, + model, + diagInvAtWA) { - /** Number of training iterations until termination */ + /** + * Number of training iterations until termination + * + * This value is only available when using the "l-bfgs" solver. + * @see [[LinearRegression.solver]] + */ @Since("1.5.0") val totalIterations = objectiveHistory.length @@ -537,7 +549,11 @@ class LinearRegressionTrainingSummary private[regression] ( * :: Experimental :: * Linear regression results evaluated on a dataset. * - * @param predictions predictions outputted by the model's `transform` method. + * @param predictions predictions output by the model's `transform` method. + * @param predictionCol Field in "predictions" which gives the predicted value of the label at + * each instance. + * @param labelCol Field in "predictions" which gives the true label of each instance. + * @param featuresCol Field in "predictions" which gives the features of each instance as a vector. */ @Since("1.5.0") @Experimental @@ -545,12 +561,13 @@ class LinearRegressionSummary private[regression] ( @transient val predictions: DataFrame, val predictionCol: String, val labelCol: String, + val featuresCol: String, val model: LinearRegressionModel, private val diagInvAtWA: Array[Double]) extends Serializable { @transient private val metrics = new RegressionMetrics( predictions - .select(predictionCol, labelCol) + .select(col(predictionCol), col(labelCol).cast(DoubleType)) .rdd .map { case Row(pred: Double, label: Double) => (pred, label) }, !model.getFitIntercept) @@ -638,6 +655,12 @@ class LinearRegressionSummary private[regression] ( /** * Standard error of estimated coefficients and intercept. + * This value is only available when using the "normal" solver. + * + * If [[LinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. + * + * @see [[LinearRegression.solver]] */ lazy val coefficientStandardErrors: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { @@ -653,12 +676,18 @@ class LinearRegressionSummary private[regression] ( col(model.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0) } val sigma2 = rss / degreesOfFreedom - diagInvAtWA.map(_ * sigma2).map(math.sqrt(_)) + diagInvAtWA.map(_ * sigma2).map(math.sqrt) } } /** * T-statistic of estimated coefficients and intercept. + * This value is only available when using the "normal" solver. + * + * If [[LinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. + * + * @see [[LinearRegression.solver]] */ lazy val tValues: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { @@ -676,6 +705,12 @@ class LinearRegressionSummary private[regression] ( /** * Two-sided p-value of estimated coefficients and intercept. + * This value is only available when using the "normal" solver. + * + * If [[LinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. + * + * @see [[LinearRegression.solver]] */ lazy val pValues: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { @@ -826,7 +861,7 @@ private class LeastSquaresAggregator( instance match { case Instance(label, weight, features) => require(dim == features.size, s"Dimensions mismatch when adding new sample." + s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 798947b94a..4c4ff278d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -17,18 +17,22 @@ package org.apache.spark.ml.regression +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -41,7 +45,7 @@ import org.apache.spark.sql.functions._ @Experimental final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] - with RandomForestParams with TreeRegressorParams { + with RandomForestRegressorParams with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("rfr")) @@ -89,7 +93,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) - override protected def train(dataset: DataFrame): RandomForestRegressionModel = { + override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -108,7 +112,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val @Since("1.4.0") @Experimental -object RandomForestRegressor { +object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{ /** Accessor for supported impurity settings: variance */ @Since("1.4.0") final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities @@ -117,12 +121,17 @@ object RandomForestRegressor { @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = RandomForestParams.supportedFeatureSubsetStrategies + + @Since("2.0.0") + override def load(path: String): RandomForestRegressor = super.load(path) + } /** * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. * It supports both continuous and categorical features. + * * @param _trees Decision trees in the ensemble. * @param numFeatures Number of features used by this model */ @@ -133,27 +142,29 @@ final class RandomForestRegressionModel private[ml] ( private val _trees: Array[DecisionTreeRegressionModel], override val numFeatures: Int) extends PredictionModel[Vector, RandomForestRegressionModel] - with TreeEnsembleModel with Serializable { + with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { - require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.") + require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.") /** * Construct a random forest regression model, with all trees weighted equally. + * * @param trees Component trees */ private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) = this(Identifiable.randomUID("rfr"), trees, numFeatures) @Since("1.4.0") - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + override def trees: Array[DecisionTreeRegressionModel] = _trees // Note: We may add support for weights (based on tree performance) later on. - private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0) @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) @@ -165,9 +176,17 @@ final class RandomForestRegressionModel private[ml] ( // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. - _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees + _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees } + /** + * Number of trees in ensemble + * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0 + */ + // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams + @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0") + val numTrees: Int = trees.length + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestRegressionModel = { copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) @@ -175,36 +194,83 @@ final class RandomForestRegressionModel private[ml] ( @Since("1.4.0") override def toString: String = { - s"RandomForestRegressionModel (uid=$uid) with $numTrees trees" + s"RandomForestRegressionModel (uid=$uid) with $getNumTrees trees" } /** * Estimate of the importance of each feature. * - * This generalizes the idea of "Gini" importance to other losses, - * following the explanation of Gini importance from "Random Forests" documentation - * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * Each feature's importance is the average of its importance across all trees in the ensemble + * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + * and follows the implementation from scikit-learn. * - * This feature importance is calculated as follows: - * - Average over trees: - * - importance(feature j) = sum (over nodes which split on feature j) of the gain, - * where gain is scaled by the number of instances passing through node - * - Normalize importances for tree to sum to 1. - * - Normalize feature importance vector to sum to 1. + * @see [[DecisionTreeRegressionModel.featureImportances]] */ @Since("1.5.0") - lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) + lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldRandomForestModel = { new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld)) } + + @Since("2.0.0") + override def write: MLWriter = + new RandomForestRegressionModel.RandomForestRegressionModelWriter(this) } -private[ml] object RandomForestRegressionModel { +@Since("2.0.0") +object RandomForestRegressionModel extends MLReadable[RandomForestRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[RandomForestRegressionModel] = new RandomForestRegressionModelReader + + @Since("2.0.0") + override def load(path: String): RandomForestRegressionModel = super.load(path) + + private[RandomForestRegressionModel] + class RandomForestRegressionModelWriter(instance: RandomForestRegressionModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class RandomForestRegressionModelReader extends MLReader[RandomForestRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[RandomForestRegressionModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): RandomForestRegressionModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldRandomForestModel, parent: RandomForestRegressor, categoricalFeatures: Map[Int, Int], @@ -215,6 +281,7 @@ private[ml] object RandomForestRegressionModel { // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestRegressionModel(parent.uid, newTrees, numFeatures) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfr") + new RandomForestRegressionModel(uid, newTrees, numFeatures) } } |