aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2016-04-28 11:22:13 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-28 11:22:13 -0700
commit5ee72454df21ef4668c855134627d0cdf5d35132 (patch)
tree6a38fd5cb3f4cc45f2714284c3acff9f219f7075 /mllib/src/main/scala
parent12c360c057f09d13a31c458ad277640b5f6de394 (diff)
downloadspark-5ee72454df21ef4668c855134627d0cdf5d35132.tar.gz
spark-5ee72454df21ef4668c855134627d0cdf5d35132.tar.bz2
spark-5ee72454df21ef4668c855134627d0cdf5d35132.zip
[SPARK-14852][ML] refactored GLM summary into training, non-training summaries
## What changes were proposed in this pull request? This splits GeneralizedLinearRegressionSummary into 2 summary types: * GeneralizedLinearRegressionSummary, which does not store info from fitting (diagInvAtWA) * GeneralizedLinearRegressionTrainingSummary, which is a subclass of GeneralizedLinearRegressionSummary and stores info from fitting. This also adds a method evaluate() which can produce a GeneralizedLinearRegressionSummary on a new dataset. The summary no longer provides the model itself as a public val. Also: * Fixes bug where GeneralizedLinearRegressionTrainingSummary was created with model, not summaryModel. * Adds hasSummary method. * Renames findSummaryModelAndPredictionCol -> getSummaryModel and simplifies that method. * In summary, extract values from model immediately in case user later changes those (e.g., predictionCol). * Pardon the style fixes; that is IntelliJ being obnoxious. ## How was this patch tested? Existing unit tests + updated test for evaluate and hasSummary. Author: Joseph K. Bradley <joseph@databricks.com> Closes #12624 from jkbradley/model-summary-api.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala156
1 files changed, 101 insertions, 55 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index dcf69afe0d..bf9d3ff30c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+
/**
* Params for Generalized Linear Regression.
*/
@@ -81,6 +82,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
/**
* Param for link prediction (linear predictor) column name.
* Default is empty, which means we do not output link prediction.
+ *
* @group param
*/
@Since("2.0.0")
@@ -144,6 +146,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
/**
* Sets the value of param [[family]].
* Default is "gaussian".
+ *
* @group setParam
*/
@Since("2.0.0")
@@ -152,6 +155,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
/**
* Sets the value of param [[link]].
+ *
* @group setParam
*/
@Since("2.0.0")
@@ -160,6 +164,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
/**
* Sets if we should fit the intercept.
* Default is true.
+ *
* @group setParam
*/
@Since("2.0.0")
@@ -168,6 +173,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
/**
* Sets the maximum number of iterations.
* Default is 25 if the solver algorithm is "irls".
+ *
* @group setParam
*/
@Since("2.0.0")
@@ -177,6 +183,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
* Sets the convergence tolerance of iterations.
* Smaller value will lead to higher accuracy with the cost of more iterations.
* Default is 1E-6.
+ *
* @group setParam
*/
@Since("2.0.0")
@@ -190,6 +197,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
* 0.5 * regParam * L2norm(coefficients)^2
* }}}
* Default is 0.0.
+ *
* @group setParam
*/
@Since("2.0.0")
@@ -200,6 +208,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is empty, so all instances have weight one.
+ *
* @group setParam
*/
@Since("2.0.0")
@@ -209,6 +218,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
/**
* Sets the solver algorithm used for optimization.
* Currently only support "irls" which is also the default solver.
+ *
* @group setParam
*/
@Since("2.0.0")
@@ -217,6 +227,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
/**
* Sets the link prediction (linear predictor) column name.
+ *
* @group setParam
*/
@Since("2.0.0")
@@ -256,15 +267,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
val model = copyValues(
new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept)
.setParent(this))
- // Handle possible missing or invalid prediction columns
- val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
- val trainingSummary = new GeneralizedLinearRegressionSummary(
- summaryModel.transform(dataset),
- predictionColName,
- model,
- wlsModel.diagInvAtWA.toArray,
- 1,
- getSolver)
+ val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
+ wlsModel.diagInvAtWA.toArray, 1, getSolver)
return model.setSummary(trainingSummary)
}
@@ -277,16 +281,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
val model = copyValues(
new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept)
.setParent(this))
- // Handle possible missing or invalid prediction columns
- val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
- val trainingSummary = new GeneralizedLinearRegressionSummary(
- summaryModel.transform(dataset),
- predictionColName,
- model,
- irlsModel.diagInvAtWA.toArray,
- irlsModel.numIterations,
- getSolver)
-
+ val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
+ irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver)
model.setSummary(trainingSummary)
}
@@ -363,6 +359,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
/**
* A description of the error distribution to be used in the model.
+ *
* @param name the name of the family.
*/
private[ml] abstract class Family(val name: String) extends Serializable {
@@ -381,6 +378,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
/**
* Akaike's 'An Information Criterion'(AIC) value of the family for a given dataset.
+ *
* @param predictions an RDD of (y, mu, weight) of instances in evaluation dataset
* @param deviance the deviance for the fitted model in evaluation dataset
* @param numInstances number of instances in evaluation dataset
@@ -400,6 +398,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
/**
* Gets the [[Family]] object from its name.
+ *
* @param name family name: "gaussian", "binomial", "poisson" or "gamma".
*/
def fromName(name: String): Family = {
@@ -579,6 +578,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* A description of the link function to be used in the model.
* The link function provides the relationship between the linear predictor
* and the mean of the distribution function.
+ *
* @param name the name of link function.
*/
private[ml] abstract class Link(val name: String) extends Serializable {
@@ -597,6 +597,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
/**
* Gets the [[Link]] object from its name.
+ *
* @param name link name: "identity", "logit", "log",
* "inverse", "probit", "cloglog" or "sqrt".
*/
@@ -694,6 +695,7 @@ class GeneralizedLinearRegressionModel private[ml] (
/**
* Sets the link prediction (linear predictor) column name.
+ *
* @group setParam
*/
@Since("2.0.0")
@@ -736,39 +738,39 @@ class GeneralizedLinearRegressionModel private[ml] (
if ($(linkPredictionCol).nonEmpty) {
output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol))))
}
- output.toDF
+ output.toDF()
}
- private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None
+ private var trainingSummary: Option[GeneralizedLinearRegressionTrainingSummary] = None
/**
* Gets R-like summary of model on training set. An exception is
- * thrown if `trainingSummary == None`.
+ * thrown if there is no summary available.
*/
@Since("2.0.0")
- def summary: GeneralizedLinearRegressionSummary = trainingSummary.getOrElse {
+ def summary: GeneralizedLinearRegressionTrainingSummary = trainingSummary.getOrElse {
throw new SparkException(
"No training summary available for this GeneralizedLinearRegressionModel")
}
- private[regression] def setSummary(summary: GeneralizedLinearRegressionSummary): this.type = {
+ /**
+ * Indicates if [[summary]] is available.
+ */
+ @Since("2.0.0")
+ def hasSummary: Boolean = trainingSummary.nonEmpty
+
+ private[regression]
+ def setSummary(summary: GeneralizedLinearRegressionTrainingSummary): this.type = {
this.trainingSummary = Some(summary)
this
}
/**
- * If the prediction column is set returns the current model and prediction column,
- * otherwise generates a new column and sets it as the prediction column on a new copy
- * of the current model.
+ * Evaluate the model on the given dataset, returning a summary of the results.
*/
- private[regression] def findSummaryModelAndPredictionCol()
- : (GeneralizedLinearRegressionModel, String) = {
- $(predictionCol) match {
- case "" =>
- val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
- (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
- case p => (this, p)
- }
+ @Since("2.0.0")
+ def evaluate(dataset: Dataset[_]): GeneralizedLinearRegressionSummary = {
+ new GeneralizedLinearRegressionSummary(dataset, this)
}
@Since("2.0.0")
@@ -834,36 +836,55 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
/**
* :: Experimental ::
- * Summarizing Generalized Linear regression Fits.
+ * Summary of [[GeneralizedLinearRegression]] model and predictions.
*
- * @param predictions predictions output by the model's `transform` method
- * @param predictionCol field in "predictions" which gives the prediction value of each instance
- * @param model the model that should be summarized
- * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
- * @param numIterations number of iterations
- * @param solver the solver algorithm used for model training
+ * @param dataset Dataset to be summarized.
+ * @param origModel Model to be summarized. This is copied to create an internal
+ * model which cannot be modified from outside.
*/
@Since("2.0.0")
@Experimental
class GeneralizedLinearRegressionSummary private[regression] (
- @Since("2.0.0") @transient val predictions: DataFrame,
- @Since("2.0.0") val predictionCol: String,
- @Since("2.0.0") val model: GeneralizedLinearRegressionModel,
- private val diagInvAtWA: Array[Double],
- @Since("2.0.0") val numIterations: Int,
- @Since("2.0.0") val solver: String) extends Serializable {
+ dataset: Dataset[_],
+ origModel: GeneralizedLinearRegressionModel) extends Serializable {
import GeneralizedLinearRegression._
- private lazy val family = Family.fromName(model.getFamily)
- private lazy val link = if (model.isDefined(model.getParam("link"))) {
+ /**
+ * Field in "predictions" which gives the prediction value of each instance.
+ * This is set to a new column name if the original model's `predictionCol` is not set.
+ */
+ @Since("2.0.0")
+ val predictionCol: String = {
+ if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol != "") {
+ origModel.getPredictionCol
+ } else {
+ "prediction_" + java.util.UUID.randomUUID.toString
+ }
+ }
+
+ /**
+ * Private copy of model to ensure Params are not modified outside this class.
+ * Coefficients is not a deep copy, but that is acceptable.
+ *
+ * NOTE: [[predictionCol]] must be set correctly before the value of [[model]] is set,
+ * and [[model]] must be set before [[predictions]] is set!
+ */
+ protected val model: GeneralizedLinearRegressionModel =
+ origModel.copy(ParamMap.empty).setPredictionCol(predictionCol)
+
+ /** predictions output by the model's `transform` method */
+ @Since("2.0.0") @transient val predictions: DataFrame = model.transform(dataset)
+
+ private[regression] lazy val family: Family = Family.fromName(model.getFamily)
+ private[regression] lazy val link: Link = if (model.isDefined(model.link)) {
Link.fromName(model.getLink)
} else {
family.defaultLink
}
/** Number of instances in DataFrame predictions */
- private lazy val numInstances: Long = predictions.count()
+ private[regression] lazy val numInstances: Long = predictions.count()
/** The numeric rank of the fitted linear model */
@Since("2.0.0")
@@ -891,7 +912,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
numInstances
}
- private lazy val devianceResiduals: DataFrame = {
+ private[regression] lazy val devianceResiduals: DataFrame = {
val drUDF = udf { (y: Double, mu: Double, weight: Double) =>
val r = math.sqrt(math.max(family.deviance(y, mu, weight), 0.0))
if (y > mu) r else -1.0 * r
@@ -901,19 +922,19 @@ class GeneralizedLinearRegressionSummary private[regression] (
drUDF(col(model.getLabelCol), col(predictionCol), w).as("devianceResiduals"))
}
- private lazy val pearsonResiduals: DataFrame = {
+ private[regression] lazy val pearsonResiduals: DataFrame = {
val prUDF = udf { mu: Double => family.variance(mu) }
val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol)
predictions.select(col(model.getLabelCol).minus(col(predictionCol))
.multiply(sqrt(w)).divide(sqrt(prUDF(col(predictionCol)))).as("pearsonResiduals"))
}
- private lazy val workingResiduals: DataFrame = {
+ private[regression] lazy val workingResiduals: DataFrame = {
val wrUDF = udf { (y: Double, mu: Double) => (y - mu) * link.deriv(mu) }
predictions.select(wrUDF(col(model.getLabelCol), col(predictionCol)).as("workingResiduals"))
}
- private lazy val responseResiduals: DataFrame = {
+ private[regression] lazy val responseResiduals: DataFrame = {
predictions.select(col(model.getLabelCol).minus(col(predictionCol)).as("responseResiduals"))
}
@@ -925,6 +946,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
/**
* Get the residuals of the fitted model by type.
+ *
* @param residualsType The type of residuals which should be returned.
* Supported options: deviance, pearson, working and response.
*/
@@ -996,6 +1018,30 @@ class GeneralizedLinearRegressionSummary private[regression] (
}
family.aic(t, deviance, numInstances, weightSum) + 2 * rank
}
+}
+
+/**
+ * :: Experimental ::
+ * Summary of [[GeneralizedLinearRegression]] fitting and model.
+ *
+ * @param dataset Dataset to be summarized.
+ * @param origModel Model to be summarized. This is copied to create an internal
+ * model which cannot be modified from outside.
+ * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
+ * @param numIterations number of iterations
+ * @param solver the solver algorithm used for model training
+ */
+@Since("2.0.0")
+@Experimental
+class GeneralizedLinearRegressionTrainingSummary private[regression] (
+ dataset: Dataset[_],
+ origModel: GeneralizedLinearRegressionModel,
+ private val diagInvAtWA: Array[Double],
+ @Since("2.0.0") val numIterations: Int,
+ @Since("2.0.0") val solver: String)
+ extends GeneralizedLinearRegressionSummary(dataset, origModel) with Serializable {
+
+ import GeneralizedLinearRegression._
/**
* Standard error of estimated coefficients and intercept.