Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala        128
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala         15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala                  155
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala    44
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala             16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala               77
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala         121
7 files changed, 386 insertions, 170 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index ba5708ab8d..89ba6ab5d2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -31,8 +31,9 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -103,7 +104,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
if (fitting) {
SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
}
if (hasQuantilesCol) {
SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT)
@@ -183,24 +184,35 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
* Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
* and put it in an RDD with strong types.
*/
- protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = {
- dataset.select($(featuresCol), $(labelCol), $(censorCol)).rdd.map {
- case Row(features: Vector, label: Double, censor: Double) =>
- AFTPoint(features, label, censor)
- }
+ protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = {
+ dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
+ .rdd.map {
+ case Row(features: Vector, label: Double, censor: Double) =>
+ AFTPoint(features, label, censor)
+ }
}
- @Since("1.6.0")
- override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
validateAndTransformSchema(dataset.schema, fitting = true)
val instances = extractAFTPoints(dataset)
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
- val costFun = new AFTCostFun(instances, $(fitIntercept))
+ val featuresSummarizer = {
+ val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
+ val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => {
+ c1.merge(c2)
+ }
+ instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp)
+ }
+
+ val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+
+ val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd)
val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
- val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size
+ val numFeatures = featuresStd.size
/*
The parameters vector has three parts:
the first element: Double, log(sigma), the log of scale parameter
@@ -229,7 +241,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
if (handlePersistence) instances.unpersist()
- val coefficients = Vectors.dense(parameters.slice(2, parameters.length))
+ val rawCoefficients = parameters.slice(2, parameters.length)
+ var i = 0
+ while (i < numFeatures) {
+ rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
+ i += 1
+ }
+ val coefficients = Vectors.dense(rawCoefficients)
val intercept = parameters(1)
val scale = math.exp(parameters(0))
val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
@@ -298,8 +316,8 @@ class AFTSurvivalRegressionModel private[ml] (
math.exp(BLAS.dot(coefficients, features) + intercept)
}
- @Since("1.6.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
val predictUDF = udf { features: Vector => predict(features) }
val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
@@ -433,29 +451,36 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
* @param parameters including three part: The log of scale parameter, the intercept and
* regression coefficients corresponding to the features.
* @param fitIntercept Whether to fit an intercept term.
+ * @param featuresStd The standard deviation values of the features.
*/
-private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
- extends Serializable {
+private class AFTAggregator(
+ parameters: BDV[Double],
+ fitIntercept: Boolean,
+ featuresStd: Array[Double]) extends Serializable {
// the regression coefficients to the covariates
private val coefficients = parameters.slice(2, parameters.length)
- private val intercept = parameters.valueAt(1)
+ private val intercept = parameters(1)
// sigma is the scale parameter of the AFT model
private val sigma = math.exp(parameters(0))
private var totalCnt: Long = 0L
private var lossSum = 0.0
- private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length)
- private var gradientInterceptSum = 0.0
- private var gradientLogSigmaSum = 0.0
+ // Here we optimize loss function over log(sigma), intercept and coefficients
+ private val gradientSumArray = Array.ofDim[Double](parameters.length)
def count: Long = totalCnt
+ def loss: Double = {
+ require(totalCnt > 0.0, s"The number of instances should be " +
+ s"greater than 0.0, but got $totalCnt.")
+ lossSum / totalCnt
+ }
+ def gradient: BDV[Double] = {
+ require(totalCnt > 0.0, s"The number of instances should be " +
+ s"greater than 0.0, but got $totalCnt.")
+ new BDV(gradientSumArray.map(_ / totalCnt.toDouble))
+ }
- def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt
-
- // Here we optimize loss function over coefficients, intercept and log(sigma)
- def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)),
- BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble)
/**
* Add a new training data to this AFTAggregator, and update the loss and gradient
@@ -465,25 +490,32 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
* @return This AFTAggregator object.
*/
def add(data: AFTPoint): this.type = {
-
- val interceptFlag = if (fitIntercept) 1.0 else 0.0
-
- val xi = data.features.toBreeze
+ val xi = data.features
val ti = data.label
val delta = data.censor
- val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma
- lossSum += math.log(sigma) * delta
- lossSum += (math.exp(epsilon) - delta * epsilon)
+ val margin = {
+ var sum = 0.0
+ xi.foreachActive { (index, value) =>
+ if (featuresStd(index) != 0.0 && value != 0.0) {
+ sum += coefficients(index) * (value / featuresStd(index))
+ }
+ }
+ sum + intercept
+ }
+ val epsilon = (math.log(ti) - margin) / sigma
+
+ lossSum += delta * math.log(sigma) - delta * epsilon + math.exp(epsilon)
- // Sanity check (should never occur):
- assert(!lossSum.isInfinity,
- s"AFTAggregator loss sum is infinity. Error for unknown reason.")
+ val multiplier = (delta - math.exp(epsilon)) / sigma
- val deltaMinusExpEps = delta - math.exp(epsilon)
- gradientCoefficientSum += xi * deltaMinusExpEps / sigma
- gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma
- gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon
+ gradientSumArray(0) += delta + multiplier * sigma * epsilon
+ gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 }
+ xi.foreachActive { (index, value) =>
+ if (featuresStd(index) != 0.0 && value != 0.0) {
+ gradientSumArray(index + 2) += multiplier * (value / featuresStd(index))
+ }
+ }
totalCnt += 1
this
@@ -502,9 +534,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
totalCnt += other.totalCnt
lossSum += other.lossSum
- gradientCoefficientSum += other.gradientCoefficientSum
- gradientInterceptSum += other.gradientInterceptSum
- gradientLogSigmaSum += other.gradientLogSigmaSum
+ var i = 0
+ val len = this.gradientSumArray.length
+ while (i < len) {
+ this.gradientSumArray(i) += other.gradientSumArray(i)
+ i += 1
+ }
}
this
}
@@ -515,12 +550,15 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
* It returns the loss and gradient at a particular point (parameters).
* It's used in Breeze's convex optimization routines.
*/
-private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean)
- extends DiffFunction[BDV[Double]] {
+private class AFTCostFun(
+ data: RDD[AFTPoint],
+ fitIntercept: Boolean,
+ featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] {
override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
- val aftAggregator = data.treeAggregate(new AFTAggregator(parameters, fitIntercept))(
+ val aftAggregator = data.treeAggregate(
+ new AFTAggregator(parameters, fitIntercept, featuresStd))(
seqOp = (c, v) => (c, v) match {
case (aggregator, instance) => aggregator.add(instance)
},
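[Editor's note: a rough sketch of the coefficient de-standardization step this patch adds after optimization. The array values below are made-up placeholders, not from the patch: coefficients are learned against features divided by their standard deviations, so each raw coefficient is divided by the matching standard deviation to return to the original feature scale, with zero-variance features mapped to a zero coefficient.]

// Illustrative sketch only: rescale coefficients learned in the standardized
// feature space back to the original scale (placeholder values).
val featuresStd: Array[Double] = Array(2.0, 0.0, 0.5)      // per-feature standard deviations
val rawCoefficients: Array[Double] = Array(0.8, 1.2, -0.3) // learned on standardized features

val coefficients: Array[Double] = rawCoefficients.zip(featuresStd).map {
  case (raw, std) => if (std != 0.0) raw / std else 0.0    // constant features contribute nothing
}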
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 50ac96eb5e..c04c416aaf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -83,7 +83,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
/** @group setParam */
def setVarianceCol(value: String): this.type = set(varianceCol, value)
- override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
+ override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -158,15 +158,16 @@ final class DecisionTreeRegressionModel private[ml] (
rootNode.predictImpl(features).impurityStats.calculate()
}
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
transformImpl(dataset)
}
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Vector) => predict(features) }
val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
- var output = dataset
+ var output = dataset.toDF
if ($(predictionCol).nonEmpty) {
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -203,9 +204,9 @@ final class DecisionTreeRegressionModel private[ml] (
* to determine feature importance instead.
*/
@Since("2.0.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
- /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
override private[spark] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index da5b77e8fa..741724d7a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -18,23 +18,23 @@
package org.apache.spark.ml.regression
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel,
- TreeRegressorParams}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
- SquaredError => OldSquaredError}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
/**
@@ -42,12 +42,24 @@ import org.apache.spark.sql.functions._
* [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
* learning algorithm for regression.
* It supports both continuous and categorical features.
+ *
+ * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes on Gradient Boosting vs. TreeBoost:
+ * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
+ * - Both algorithms learn tree ensembles by minimizing loss functions.
+ * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
+ * based on the loss function, whereas the original gradient boosting method does not.
+ * - When the loss is SquaredError, these methods give the same result, but they could differ
+ * for other loss functions.
+ * - We expect to implement TreeBoost in the future:
+ * [https://issues.apache.org/jira/browse/SPARK-4240]
*/
@Since("1.4.0")
@Experimental
final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
- with GBTParams with TreeRegressorParams with Logging {
+ with GBTRegressorParams with DefaultParamsWritable with Logging {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("gbtr"))
@@ -101,42 +113,13 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
@Since("1.4.0")
override def setStepSize(value: Double): this.type = super.setStepSize(value)
- // Parameters for GBTRegressor:
-
- /**
- * Loss function which GBT tries to minimize. (case-insensitive)
- * Supported: "squared" (L2) and "absolute" (L1)
- * (default = squared)
- * @group param
- */
- @Since("1.4.0")
- val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
- " tries to minimize (case-insensitive). Supported options:" +
- s" ${GBTRegressor.supportedLossTypes.mkString(", ")}",
- (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase))
-
- setDefault(lossType -> "squared")
+ // Parameters from GBTRegressorParams:
/** @group setParam */
@Since("1.4.0")
def setLossType(value: String): this.type = set(lossType, value)
- /** @group getParam */
- @Since("1.4.0")
- def getLossType: String = $(lossType).toLowerCase
-
- /** (private[ml]) Convert new loss to old loss. */
- override private[ml] def getOldLossType: OldLoss = {
- getLossType match {
- case "squared" => OldSquaredError
- case "absolute" => OldAbsoluteError
- case _ =>
- // Should never happen because of check in setter method.
- throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
- }
- }
-
- override protected def train(dataset: DataFrame): GBTRegressionModel = {
+ override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -153,11 +136,14 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
@Since("1.4.0")
@Experimental
-object GBTRegressor {
- // The losses below should be lowercase.
+object GBTRegressor extends DefaultParamsReadable[GBTRegressor] {
+
/** Accessor for supported loss settings: squared (L2), absolute (L1) */
@Since("1.4.0")
- final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+ final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes
+
+ @Since("2.0.0")
+ override def load(path: String): GBTRegressor = super.load(path)
}
/**
@@ -177,9 +163,10 @@ final class GBTRegressionModel private[ml](
private val _treeWeights: Array[Double],
override val numFeatures: Int)
extends PredictionModel[Vector, GBTRegressionModel]
- with TreeEnsembleModel with Serializable {
+ with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
- require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
@@ -193,12 +180,12 @@ final class GBTRegressionModel private[ml](
this(uid, _trees, _treeWeights, -1)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
@@ -213,6 +200,9 @@ final class GBTRegressionModel private[ml](
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
}
+ /** Number of trees in ensemble */
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): GBTRegressionModel = {
copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
@@ -224,16 +214,81 @@ final class GBTRegressionModel private[ml](
s"GBTRegressionModel (uid=$uid) with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
+ *
+ * @see [[DecisionTreeRegressionModel.featureImportances]]
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this)
}
-private[ml] object GBTRegressionModel {
+@Since("2.0.0")
+object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): GBTRegressionModel = super.load(path)
+
+ private[GBTRegressionModel]
+ class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[GBTRegressionModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): GBTRegressionModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+
+ require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldGBTModel,
parent: GBTRegressor,
categoricalFeatures: Map[Int, Int],
@@ -245,6 +300,6 @@ private[ml] object GBTRegressionModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
- new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
+ new GBTRegressionModel(uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
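[Editor's note: the prediction in GBTRegressionModel above is a weighted sum of the per-tree predictions, computed with blas.ddot. A minimal equivalent without BLAS, shown only to make the arithmetic explicit; the method name is illustrative, not part of the patch.]

// Equivalent of blas.ddot(numTrees, treePredictions, 1, treeWeights, 1): sum of p_i * w_i.
def weightedTreeSum(treePredictions: Array[Double], treeWeights: Array[Double]): Double = {
  require(treePredictions.length == treeWeights.length, "trees and weights must align")
  treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum
}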
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 0e71e8d8e1..e92a3e7fa1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -31,9 +31,9 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
* Params for Generalized Linear Regression.
@@ -47,6 +47,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
* to be used in the model.
* Supported options: "gaussian", "binomial", "poisson" and "gamma".
* Default is "gaussian".
+ *
* @group param
*/
@Since("2.0.0")
@@ -63,6 +64,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
* Param for the name of link function which provides the relationship
* between the linear predictor and the mean of the distribution function.
* Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt".
+ *
* @group param
*/
@Since("2.0.0")
@@ -163,7 +165,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
setDefault(tol -> 1E-6)
/**
- * Sets the regularization parameter.
+ * Sets the regularization parameter for L2 regularization.
+ * The regularization term is
+ * {{{
+ * 0.5 * regParam * L2norm(coefficients)^2
+ * }}}
* Default is 0.0.
* @group setParam
*/
@@ -190,7 +196,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
def setSolver(value: String): this.type = set(solver, value)
setDefault(solver -> "irls")
- override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = {
+ override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
val familyObj = Family.fromName($(family))
val linkObj = if (isDefined(link)) {
Link.fromName($(link))
@@ -210,9 +216,10 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
}
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
- val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
- .map { case Row(label: Double, weight: Double, features: Vector) =>
- Instance(label, weight, features)
+ val instances: RDD[Instance] =
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
}
if (familyObj == Gaussian && linkObj == Identity) {
@@ -230,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
wlsModel.diagInvAtWA.toArray,
- 1)
+ 1,
+ getSolver)
return model.setSummary(trainingSummary)
}
@@ -250,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
irlsModel.diagInvAtWA.toArray,
- irlsModel.numIterations)
+ irlsModel.numIterations,
+ getSolver)
model.setSummary(trainingSummary)
}
@@ -698,7 +707,7 @@ class GeneralizedLinearRegressionModel private[ml] (
: (GeneralizedLinearRegressionModel, String) = {
$(predictionCol) match {
case "" =>
- val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+ val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
(copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
case p => (this, p)
}
@@ -769,11 +778,12 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
* :: Experimental ::
* Summarizing Generalized Linear regression Fits.
*
- * @param predictions predictions outputted by the model's `transform` method
+ * @param predictions predictions output by the model's `transform` method
* @param predictionCol field in "predictions" which gives the prediction value of each instance
* @param model the model that should be summarized
* @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
* @param numIterations number of iterations
+ * @param solver the solver algorithm used for model training
*/
@Since("2.0.0")
@Experimental
@@ -782,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
@Since("2.0.0") val predictionCol: String,
@Since("2.0.0") val model: GeneralizedLinearRegressionModel,
private val diagInvAtWA: Array[Double],
- @Since("2.0.0") val numIterations: Int) extends Serializable {
+ @Since("2.0.0") val numIterations: Int,
+ @Since("2.0.0") val solver: String) extends Serializable {
import GeneralizedLinearRegression._
@@ -930,6 +941,9 @@ class GeneralizedLinearRegressionSummary private[regression] (
/**
* Standard error of estimated coefficients and intercept.
+ *
+ * If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val coefficientStandardErrors: Array[Double] = {
@@ -938,6 +952,9 @@ class GeneralizedLinearRegressionSummary private[regression] (
/**
* T-statistic of estimated coefficients and intercept.
+ *
+ * If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val tValues: Array[Double] = {
@@ -951,6 +968,9 @@ class GeneralizedLinearRegressionSummary private[regression] (
/**
* Two-sided p-value of estimated coefficients and intercept.
+ *
+ * If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val pValues: Array[Double] = {
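[Editor's note: the setRegParam doc added above defines the penalty as 0.5 * regParam * L2norm(coefficients)^2. A one-line sketch of that term; the function name is illustrative, not part of the patch.]

// 0.5 * regParam * ||coefficients||_2^2
def l2Penalty(regParam: Double, coefficients: Array[Double]): Double =
  0.5 * regParam * coefficients.map(c => c * c).sum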
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index fb733f9a34..7a78ecbdf1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression}
import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -77,7 +77,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
* Extracts (label, feature, weight) from input dataset.
*/
protected[ml] def extractWeightedLabeledPoints(
- dataset: DataFrame): RDD[(Double, Double, Double)] = {
+ dataset: Dataset[_]): RDD[(Double, Double, Double)] = {
val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) {
val idx = $(featureIndex)
val extract = udf { v: Vector => v(idx) }
@@ -90,7 +90,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
} else {
lit(1.0)
}
- dataset.select(col($(labelCol)), f, w).rdd.map {
+ dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
case Row(label: Double, feature: Double, weight: Double) =>
(label, feature, weight)
}
@@ -106,7 +106,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
schema: StructType,
fitting: Boolean): StructType = {
if (fitting) {
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
if (hasWeightCol) {
SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType)
} else {
@@ -164,8 +164,8 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
@Since("1.5.0")
override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
- @Since("1.5.0")
- override def fit(dataset: DataFrame): IsotonicRegressionModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): IsotonicRegressionModel = {
validateAndTransformSchema(dataset.schema, fitting = true)
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val instances = extractWeightedLabeledPoints(dataset)
@@ -236,8 +236,8 @@ class IsotonicRegressionModel private[ml] (
copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent)
}
- @Since("1.5.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val predict = dataset.schema($(featuresCol)).dataType match {
case DoubleType =>
udf { feature: Double => oldModel.predict(feature) }
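[Editor's note: several of these diffs replace the strict DoubleType check on the label column with checkNumericType plus an explicit cast when extracting rows. A minimal sketch of that extraction pattern, assuming a DataFrame with "label", "feature" and "weight" columns; the column and method names are placeholders.]

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

// Accept any numeric label column and cast it to Double before pattern matching on Row.
def extractPoints(dataset: DataFrame): RDD[(Double, Double, Double)] =
  dataset.select(col("label").cast(DoubleType), col("feature"), col("weight")).rdd.map {
    case Row(label: Double, feature: Double, weight: Double) => (label, feature, weight)
  }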
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index b81c588e44..71e02730c7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -38,8 +38,9 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
/**
@@ -57,7 +58,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams
* The specific squared error loss function used is:
* L = 1/2n ||A coefficients - y||^2^
*
- * This support multiple types of regularization:
+ * This supports multiple types of regularization:
* - none (a.k.a. ordinary least squares)
* - L2 (ridge regression)
* - L1 (Lasso)
@@ -157,7 +158,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
def setSolver(value: String): this.type = set(solver, value)
setDefault(solver -> "auto")
- override protected def train(dataset: DataFrame): LinearRegressionModel = {
+ override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
// Extract the number of features before deciding optimization solver.
val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map {
case Row(features: Vector) => features.size
@@ -171,7 +172,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
// For low dimensional data, WeightedLeastSquares is more efficiently since the
// training algorithm only requires one pass through the data. (SPARK-10668)
val instances: RDD[Instance] = dataset.select(
- col($(labelCol)), w, col($(featuresCol))).rdd.map {
+ col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
@@ -189,9 +190,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
summaryModel,
model.diagInvAtWA.toArray,
- $(featuresCol),
Array(0D))
return lrModel.setSummary(trainingSummary)
@@ -248,9 +249,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
model,
Array(0D),
- $(featuresCol),
Array(0D))
return copyValues(model.setSummary(trainingSummary))
} else {
@@ -355,9 +356,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
model,
Array(0D),
- $(featuresCol),
objectiveHistory)
model.setSummary(trainingSummary)
}
@@ -412,15 +413,15 @@ class LinearRegressionModel private[ml] (
def hasSummary: Boolean = trainingSummary.isDefined
/**
- * Evaluates the model on a testset.
+ * Evaluates the model on a test dataset.
* @param dataset Test dataset to evaluate model on.
*/
- // TODO: decide on a good name before exposing to public API
- private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
+ @Since("2.0.0")
+ def evaluate(dataset: Dataset[_]): LinearRegressionSummary = {
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName,
- $(labelCol), this, Array(0D))
+ $(labelCol), $(featuresCol), summaryModel, Array(0D))
}
/**
@@ -431,7 +432,7 @@ class LinearRegressionModel private[ml] (
private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = {
$(predictionCol) match {
case "" =>
- val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+ val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
(copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
case p => (this, p)
}
@@ -510,9 +511,9 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
/**
* :: Experimental ::
* Linear regression training results. Currently, the training summary ignores the
- * training coefficients except for the objective trace.
+ * training weights except for the objective trace.
*
- * @param predictions predictions outputted by the model's `transform` method.
+ * @param predictions predictions output by the model's `transform` method.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
@Since("1.5.0")
@@ -521,13 +522,24 @@ class LinearRegressionTrainingSummary private[regression] (
predictions: DataFrame,
predictionCol: String,
labelCol: String,
+ featuresCol: String,
model: LinearRegressionModel,
diagInvAtWA: Array[Double],
- val featuresCol: String,
val objectiveHistory: Array[Double])
- extends LinearRegressionSummary(predictions, predictionCol, labelCol, model, diagInvAtWA) {
+ extends LinearRegressionSummary(
+ predictions,
+ predictionCol,
+ labelCol,
+ featuresCol,
+ model,
+ diagInvAtWA) {
- /** Number of training iterations until termination */
+ /**
+ * Number of training iterations until termination
+ *
+ * This value is only available when using the "l-bfgs" solver.
+ * @see [[LinearRegression.solver]]
+ */
@Since("1.5.0")
val totalIterations = objectiveHistory.length
@@ -537,7 +549,11 @@ class LinearRegressionTrainingSummary private[regression] (
* :: Experimental ::
* Linear regression results evaluated on a dataset.
*
- * @param predictions predictions outputted by the model's `transform` method.
+ * @param predictions predictions output by the model's `transform` method.
+ * @param predictionCol Field in "predictions" which gives the predicted value of the label at
+ * each instance.
+ * @param labelCol Field in "predictions" which gives the true label of each instance.
+ * @param featuresCol Field in "predictions" which gives the features of each instance as a vector.
*/
@Since("1.5.0")
@Experimental
@@ -545,12 +561,13 @@ class LinearRegressionSummary private[regression] (
@transient val predictions: DataFrame,
val predictionCol: String,
val labelCol: String,
+ val featuresCol: String,
val model: LinearRegressionModel,
private val diagInvAtWA: Array[Double]) extends Serializable {
@transient private val metrics = new RegressionMetrics(
predictions
- .select(predictionCol, labelCol)
+ .select(col(predictionCol), col(labelCol).cast(DoubleType))
.rdd
.map { case Row(pred: Double, label: Double) => (pred, label) },
!model.getFitIntercept)
@@ -638,6 +655,12 @@ class LinearRegressionSummary private[regression] (
/**
* Standard error of estimated coefficients and intercept.
+ * This value is only available when using the "normal" solver.
+ *
+ * If [[LinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
+ *
+ * @see [[LinearRegression.solver]]
*/
lazy val coefficientStandardErrors: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -653,12 +676,18 @@ class LinearRegressionSummary private[regression] (
col(model.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0)
}
val sigma2 = rss / degreesOfFreedom
- diagInvAtWA.map(_ * sigma2).map(math.sqrt(_))
+ diagInvAtWA.map(_ * sigma2).map(math.sqrt)
}
}
/**
* T-statistic of estimated coefficients and intercept.
+ * This value is only available when using the "normal" solver.
+ *
+ * If [[LinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
+ *
+ * @see [[LinearRegression.solver]]
*/
lazy val tValues: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -676,6 +705,12 @@ class LinearRegressionSummary private[regression] (
/**
* Two-sided p-value of estimated coefficients and intercept.
+ * This value is only available when using the "normal" solver.
+ *
+ * If [[LinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
+ *
+ * @see [[LinearRegression.solver]]
*/
lazy val pValues: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -826,7 +861,7 @@ private class LeastSquaresAggregator(
instance match { case Instance(label, weight, features) =>
require(dim == features.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $dim but got ${features.size}.")
- require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+ require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
if (weight == 0.0) return this
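[Editor's note: the "normal"-solver statistics documented above follow the usual least-squares formulas: the residual variance scales the diagonal of (A^T W A)^-1 to give standard errors, and t-values divide each coefficient by its standard error. A hedged sketch with illustrative names, not the summary class itself.]

// Standard errors and t-statistics from diag((A^T W A)^-1), rss and degrees of freedom.
def normalSolverStats(
    diagInvAtWA: Array[Double],
    rss: Double,
    degreesOfFreedom: Long,
    coefficientsWithIntercept: Array[Double]): (Array[Double], Array[Double]) = {
  val sigma2 = rss / degreesOfFreedom
  val stdErrs = diagInvAtWA.map(_ * sigma2).map(math.sqrt)
  val tValues = coefficientsWithIntercept.zip(stdErrs).map { case (beta, se) => beta / se }
  (stdErrs, tValues)
}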
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 798947b94a..4c4ff278d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -17,18 +17,22 @@
package org.apache.spark.ml.regression
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -41,7 +45,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
- with RandomForestParams with TreeRegressorParams {
+ with RandomForestRegressorParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfr"))
@@ -89,7 +93,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)
- override protected def train(dataset: DataFrame): RandomForestRegressionModel = {
+ override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -108,7 +112,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val
@Since("1.4.0")
@Experimental
-object RandomForestRegressor {
+object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{
/** Accessor for supported impurity settings: variance */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
@@ -117,12 +121,17 @@ object RandomForestRegressor {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressor = super.load(path)
+
}
/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
* It supports both continuous and categorical features.
+ *
* @param _trees Decision trees in the ensemble.
* @param numFeatures Number of features used by this model
*/
@@ -133,27 +142,29 @@ final class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
- with TreeEnsembleModel with Serializable {
+ with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
- require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
/**
* Construct a random forest regression model, with all trees weighted equally.
+ *
* @param trees Component trees
*/
private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
this(Identifiable.randomUID("rfr"), trees, numFeatures)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
// Note: We may add support for weights (based on tree performance) later on.
- private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
@@ -165,9 +176,17 @@ final class RandomForestRegressionModel private[ml] (
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
- _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
+ _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
}
+ /**
+ * Number of trees in ensemble
+ * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
+ */
+ // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams
+ @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestRegressionModel = {
copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
@@ -175,36 +194,83 @@ final class RandomForestRegressionModel private[ml] (
@Since("1.4.0")
override def toString: String = {
- s"RandomForestRegressionModel (uid=$uid) with $numTrees trees"
+ s"RandomForestRegressionModel (uid=$uid) with $getNumTrees trees"
}
/**
* Estimate of the importance of each feature.
*
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
*
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
+ * @see [[DecisionTreeRegressionModel.featureImportances]]
*/
@Since("1.5.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new RandomForestRegressionModel.RandomForestRegressionModelWriter(this)
}
-private[ml] object RandomForestRegressionModel {
+@Since("2.0.0")
+object RandomForestRegressionModel extends MLReadable[RandomForestRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[RandomForestRegressionModel] = new RandomForestRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressionModel = super.load(path)
+
+ private[RandomForestRegressionModel]
+ class RandomForestRegressionModelWriter(instance: RandomForestRegressionModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class RandomForestRegressionModelReader extends MLReader[RandomForestRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RandomForestRegressionModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): RandomForestRegressionModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+ require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestRegressor,
categoricalFeatures: Map[Int, Int],
@@ -215,6 +281,7 @@ private[ml] object RandomForestRegressionModel {
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestRegressionModel(parent.uid, newTrees, numFeatures)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfr")
+ new RandomForestRegressionModel(uid, newTrees, numFeatures)
}
}
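[Editor's note: the featureImportances doc above (here and in GBTRegressionModel) describes averaging each feature's importance over all trees and normalizing the result to sum to 1. A small sketch of that aggregation, assuming per-tree importance arrays are already available; this is not TreeEnsembleModel.featureImportances itself.]

// Average per-tree importance vectors and normalize so the result sums to 1.
def averageImportances(perTreeImportances: Seq[Array[Double]]): Array[Double] = {
  val numFeatures = perTreeImportances.head.length
  val summed = Array.fill(numFeatures)(0.0)
  perTreeImportances.foreach { imp =>
    var i = 0
    while (i < numFeatures) { summed(i) += imp(i); i += 1 }
  }
  val total = summed.sum
  if (total > 0.0) summed.map(_ / total) else summed
}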