From 3665294d4e1c6ea13ee66e71cf802f1a961ab15c Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 15 Mar 2016 22:30:07 -0700 Subject: [SPARK-9837][ML] R-like summary statistics for GLMs via iteratively reweighted least squares ## What changes were proposed in this pull request? Provide R-like summary statistics for GLMs via iteratively reweighted least squares. ## How was this patch tested? unit tests. Author: Yanbo Liang Closes #11694 from yanboliang/spark-9837. --- .../optim/IterativelyReweightedLeastSquares.scala | 9 +- .../regression/GeneralizedLinearRegression.scala | 341 ++++++++++++++- .../GeneralizedLinearRegressionSuite.scala | 457 +++++++++++++++++++++ 3 files changed, 796 insertions(+), 11 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index 6aa44e6ba7..fe82324ab2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -26,10 +26,14 @@ import org.apache.spark.rdd.RDD * Model fitted by [[IterativelyReweightedLeastSquares]]. * @param coefficients model coefficients * @param intercept model intercept + * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration + * @param numIterations number of iterations */ private[ml] class IterativelyReweightedLeastSquaresModel( val coefficients: DenseVector, - val intercept: Double) extends Serializable + val intercept: Double, + val diagInvAtWA: DenseVector, + val numIterations: Int) extends Serializable /** * Implements the method of iteratively reweighted least squares (IRLS) which is used to solve @@ -103,6 +107,7 @@ private[ml] class IterativelyReweightedLeastSquares( } - new IterativelyReweightedLeastSquaresModel(model.coefficients, model.intercept) + new IterativelyReweightedLeastSquaresModel( + model.coefficients, model.intercept, model.diagInvAtWA, iter) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index de1dff9421..b4e47c8073 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.regression -import breeze.stats.distributions.{Gaussian => GD} +import breeze.stats.{distributions => dist} import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} @@ -217,7 +217,15 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val model = copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) - return model + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + val trainingSummary = new GeneralizedLinearRegressionSummary( + summaryModel.transform(dataset), + predictionColName, + model, + wlsModel.diagInvAtWA.toArray, + 1) + return model.setSummary(trainingSummary) } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). @@ -229,7 +237,16 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val model = copyValues( new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) .setParent(this)) - model + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + val trainingSummary = new GeneralizedLinearRegressionSummary( + summaryModel.transform(dataset), + predictionColName, + model, + irlsModel.diagInvAtWA.toArray, + irlsModel.numIterations) + + model.setSummary(trainingSummary) } @Since("2.0.0") @@ -318,6 +335,22 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine /** The variance of the endogenous variable's mean, given the value mu. */ def variance(mu: Double): Double + /** Deviance of (y, mu) pair. */ + def deviance(y: Double, mu: Double, weight: Double): Double + + /** + * Akaike's 'An Information Criterion'(AIC) value of the family for a given dataset. + * @param predictions an RDD of (y, mu, weight) of instances in evaluation dataset + * @param deviance the deviance for the fitted model in evaluation dataset + * @param numInstances number of instances in evaluation dataset + * @param weightSum weights sum of instances in evaluation dataset + */ + def aic( + predictions: RDD[(Double, Double, Double)], + deviance: Double, + numInstances: Double, + weightSum: Double): Double + /** Trim the fitted value so that it will be in valid range. */ def project(mu: Double): Double = mu } @@ -348,7 +381,20 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def initialize(y: Double, weight: Double): Double = y - def variance(mu: Double): Double = 1.0 + override def variance(mu: Double): Double = 1.0 + + override def deviance(y: Double, mu: Double, weight: Double): Double = { + weight * (y - mu) * (y - mu) + } + + override def aic( + predictions: RDD[(Double, Double, Double)], + deviance: Double, + numInstances: Double, + weightSum: Double): Double = { + val wt = predictions.map(x => math.log(x._3)).sum() + numInstances * (math.log(deviance / numInstances * 2.0 * math.Pi) + 1.0) + 2.0 - wt + } override def project(mu: Double): Double = { if (mu.isNegInfinity) { @@ -378,6 +424,22 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu * (1.0 - mu) + override def deviance(y: Double, mu: Double, weight: Double): Double = { + val my = 1.0 - y + 2.0 * weight * (y * math.log(math.max(y, 1.0) / mu) + + my * math.log(math.max(my, 1.0) / (1.0 - mu))) + } + + override def aic( + predictions: RDD[(Double, Double, Double)], + deviance: Double, + numInstances: Double, + weightSum: Double): Double = { + -2.0 * predictions.map { case (y: Double, mu: Double, weight: Double) => + weight * dist.Binomial(1, mu).logProbabilityOf(math.round(y).toInt) + }.sum() + } + override def project(mu: Double): Double = { if (mu < epsilon) { epsilon @@ -405,6 +467,20 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu + override def deviance(y: Double, mu: Double, weight: Double): Double = { + 2.0 * weight * (y * math.log(y / mu) - (y - mu)) + } + + override def aic( + predictions: RDD[(Double, Double, Double)], + deviance: Double, + numInstances: Double, + weightSum: Double): Double = { + -2.0 * predictions.map { case (y: Double, mu: Double, weight: Double) => + weight * dist.Poisson(mu).logProbabilityOf(y.toInt) + }.sum() + } + override def project(mu: Double): Double = { if (mu < epsilon) { epsilon @@ -430,7 +506,22 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine y } - override def variance(mu: Double): Double = math.pow(mu, 2.0) + override def variance(mu: Double): Double = mu * mu + + override def deviance(y: Double, mu: Double, weight: Double): Double = { + -2.0 * weight * (math.log(y / mu) - (y - mu)/mu) + } + + override def aic( + predictions: RDD[(Double, Double, Double)], + deviance: Double, + numInstances: Double, + weightSum: Double): Double = { + val disp = deviance / weightSum + -2.0 * predictions.map { case (y: Double, mu: Double, weight: Double) => + weight * dist.Gamma(1.0 / disp, mu * disp).logPdf(y) + }.sum() + 2.0 + } override def project(mu: Double): Double = { if (mu < epsilon) { @@ -519,11 +610,13 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[ml] object Probit extends Link("probit") { - override def link(mu: Double): Double = GD(0.0, 1.0).icdf(mu) + override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu) - override def deriv(mu: Double): Double = 1.0 / GD(0.0, 1.0).pdf(GD(0.0, 1.0).icdf(mu)) + override def deriv(mu: Double): Double = { + 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).icdf(mu)) + } - override def unlink(eta: Double): Double = GD(0.0, 1.0).cdf(eta) + override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta) } private[ml] object CLogLog extends Link("cloglog") { @@ -541,7 +634,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def deriv(mu: Double): Double = 1.0 / (2.0 * math.sqrt(mu)) - override def unlink(eta: Double): Double = math.pow(eta, 2.0) + override def unlink(eta: Double): Double = eta * eta } } @@ -573,6 +666,39 @@ class GeneralizedLinearRegressionModel private[ml] ( familyAndLink.fitted(eta) } + private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None + + /** + * Gets R-like summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + @Since("2.0.0") + def summary: GeneralizedLinearRegressionSummary = trainingSummary.getOrElse { + throw new SparkException( + "No training summary available for this GeneralizedLinearRegressionModel", + new RuntimeException()) + } + + private[regression] def setSummary(summary: GeneralizedLinearRegressionSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * If the prediction column is set returns the current model and prediction column, + * otherwise generates a new column and sets it as the prediction column on a new copy + * of the current model. + */ + private[regression] def findSummaryModelAndPredictionCol() + : (GeneralizedLinearRegressionModel, String) = { + $(predictionCol) match { + case "" => + val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() + (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName) + case p => (this, p) + } + } + @Since("2.0.0") override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) @@ -633,3 +759,200 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr } } } + +/** + * :: Experimental :: + * Summarizing Generalized Linear regression Fits. + * + * @param predictions predictions outputted by the model's `transform` method + * @param predictionCol field in "predictions" which gives the prediction value of each instance + * @param model the model that should be summarized + * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration + * @param numIterations number of iterations + */ +@Since("2.0.0") +@Experimental +class GeneralizedLinearRegressionSummary private[regression] ( + @Since("2.0.0") @transient val predictions: DataFrame, + @Since("2.0.0") val predictionCol: String, + @Since("2.0.0") val model: GeneralizedLinearRegressionModel, + private val diagInvAtWA: Array[Double], + @Since("2.0.0") val numIterations: Int) extends Serializable { + + import GeneralizedLinearRegression._ + + private lazy val family = Family.fromName(model.getFamily) + private lazy val link = if (model.isDefined(model.getParam("link"))) { + Link.fromName(model.getLink) + } else { + family.defaultLink + } + + /** Number of instances in DataFrame predictions */ + private lazy val numInstances: Long = predictions.count() + + /** The numeric rank of the fitted linear model */ + @Since("2.0.0") + lazy val rank: Long = if (model.getFitIntercept) { + model.coefficients.size + 1 + } else { + model.coefficients.size + } + + /** Degrees of freedom */ + @Since("2.0.0") + lazy val degreesOfFreedom: Long = { + numInstances - rank + } + + /** The residual degrees of freedom */ + @Since("2.0.0") + lazy val residualDegreeOfFreedom: Long = degreesOfFreedom + + /** The residual degrees of freedom for the null model */ + @Since("2.0.0") + lazy val residualDegreeOfFreedomNull: Long = if (model.getFitIntercept) { + numInstances - 1 + } else { + numInstances + } + + private lazy val devianceResiduals: DataFrame = { + val drUDF = udf { (y: Double, mu: Double, weight: Double) => + val r = math.sqrt(math.max(family.deviance(y, mu, weight), 0.0)) + if (y > mu) r else -1.0 * r + } + val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol) + predictions.select( + drUDF(col(model.getLabelCol), col(predictionCol), w).as("devianceResiduals")) + } + + private lazy val pearsonResiduals: DataFrame = { + val prUDF = udf { mu: Double => family.variance(mu) } + val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol) + predictions.select(col(model.getLabelCol).minus(col(predictionCol)) + .multiply(sqrt(w)).divide(sqrt(prUDF(col(predictionCol)))).as("pearsonResiduals")) + } + + private lazy val workingResiduals: DataFrame = { + val wrUDF = udf { (y: Double, mu: Double) => (y - mu) * link.deriv(mu) } + predictions.select(wrUDF(col(model.getLabelCol), col(predictionCol)).as("workingResiduals")) + } + + private lazy val responseResiduals: DataFrame = { + predictions.select(col(model.getLabelCol).minus(col(predictionCol)).as("responseResiduals")) + } + + /** + * Get the default residuals(deviance residuals) of the fitted model. + */ + @Since("2.0.0") + def residuals(): DataFrame = devianceResiduals + + /** + * Get the residuals of the fitted model by type. + * @param residualsType The type of residuals which should be returned. + * Supported options: deviance, pearson, working and response. + */ + @Since("2.0.0") + def residuals(residualsType: String): DataFrame = { + residualsType match { + case "deviance" => devianceResiduals + case "pearson" => pearsonResiduals + case "working" => workingResiduals + case "response" => responseResiduals + case other => throw new UnsupportedOperationException( + s"The residuals type $other is not supported by Generalized Linear Regression.") + } + } + + /** + * The deviance for the null model. + */ + @Since("2.0.0") + lazy val nullDeviance: Double = { + val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol) + val wtdmu: Double = if (model.getFitIntercept) { + val agg = predictions.agg(sum(w.multiply(col(model.getLabelCol))), sum(w)).first() + agg.getDouble(0) / agg.getDouble(1) + } else { + link.unlink(0.0) + } + predictions.select(col(model.getLabelCol), w).rdd.map { + case Row(y: Double, weight: Double) => + family.deviance(y, wtdmu, weight) + }.sum() + } + + /** + * The deviance for the fitted model. + */ + @Since("2.0.0") + lazy val deviance: Double = { + val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol) + predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map { + case Row(label: Double, pred: Double, weight: Double) => + family.deviance(label, pred, weight) + }.sum() + } + + /** + * The dispersion of the fitted model. + * It is taken as 1.0 for the "binomial" and "poisson" families, and otherwise + * estimated by the residual Pearson's Chi-Squared statistic(which is defined as + * sum of the squares of the Pearson residuals) divided by the residual degrees of freedom. + */ + @Since("2.0.0") + lazy val dispersion: Double = if ( + model.getFamily == Binomial.name || model.getFamily == Poisson.name) { + 1.0 + } else { + val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0) + rss / degreesOfFreedom + } + + /** Akaike's "An Information Criterion"(AIC) for the fitted model. */ + @Since("2.0.0") + lazy val aic: Double = { + val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol) + val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0) + val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map { + case Row(label: Double, pred: Double, weight: Double) => + (label, pred, weight) + } + family.aic(t, deviance, numInstances, weightSum) + 2 * rank + } + + /** + * Standard error of estimated coefficients and intercept. + */ + @Since("2.0.0") + lazy val coefficientStandardErrors: Array[Double] = { + diagInvAtWA.map(_ * dispersion).map(math.sqrt) + } + + /** + * T-statistic of estimated coefficients and intercept. + */ + @Since("2.0.0") + lazy val tValues: Array[Double] = { + val estimate = if (model.getFitIntercept) { + Array.concat(model.coefficients.toArray, Array(model.intercept)) + } else { + model.coefficients.toArray + } + estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 } + } + + /** + * Two-sided p-value of estimated coefficients and intercept. + */ + @Since("2.0.0") + lazy val pValues: Array[Double] = { + if (model.getFamily == Binomial.name || model.getFamily == Poisson.name) { + tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) } + } else { + tValues.map { x => 2.0 * (1.0 - dist.StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x))) } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 618304ad19..6d570f7bde 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.regression import scala.util.Random import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ @@ -29,6 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -466,6 +468,461 @@ class GeneralizedLinearRegressionSuite } } + test("glm summary: gaussian family with weight") { + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(17, 19, 23, 29) + w <- c(1, 2, 3, 4) + df <- as.data.frame(cbind(A, b)) + */ + val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) + /* + R code: + + model <- glm(formula = "b ~ .", family="gaussian", data = df, weights = w) + summary(model) + + Deviance Residuals: + 1 2 3 4 + 1.920 -1.358 -1.109 0.960 + + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) 18.080 9.608 1.882 0.311 + V1 6.080 5.556 1.094 0.471 + V2 -0.600 1.960 -0.306 0.811 + + (Dispersion parameter for gaussian family taken to be 7.68) + + Null deviance: 202.00 on 3 degrees of freedom + Residual deviance: 7.68 on 1 degrees of freedom + AIC: 18.783 + + Number of Fisher Scoring iterations: 2 + + residuals(model, type="pearson") + 1 2 3 4 + 1.920000 -1.357645 -1.108513 0.960000 + + residuals(model, type="working") + 1 2 3 4 + 1.92 -0.96 -0.64 0.48 + + residuals(model, type="response") + 1 2 3 4 + 1.92 -0.96 -0.64 0.48 + */ + val trainer = new GeneralizedLinearRegression() + .setWeightCol("weight") + + val model = trainer.fit(datasetWithWeight) + + val coefficientsR = Vectors.dense(Array(6.080, -0.600)) + val interceptR = 18.080 + val devianceResidualsR = Array(1.920, -1.358, -1.109, 0.960) + val pearsonResidualsR = Array(1.920000, -1.357645, -1.108513, 0.960000) + val workingResidualsR = Array(1.92, -0.96, -0.64, 0.48) + val responseResidualsR = Array(1.92, -0.96, -0.64, 0.48) + val seCoefR = Array(5.556, 1.960, 9.608) + val tValsR = Array(1.094, -0.306, 1.882) + val pValsR = Array(0.471, 0.811, 0.311) + val dispersionR = 7.68 + val nullDevianceR = 202.00 + val residualDevianceR = 7.68 + val residualDegreeOfFreedomNullR = 3 + val residualDegreeOfFreedomR = 1 + val aicR = 18.783 + + val summary = model.summary + + val devianceResiduals = summary.residuals() + .select(col("devianceResiduals")) + .collect() + .map(_.getDouble(0)) + val pearsonResiduals = summary.residuals("pearson") + .select(col("pearsonResiduals")) + .collect() + .map(_.getDouble(0)) + val workingResiduals = summary.residuals("working") + .select(col("workingResiduals")) + .collect() + .map(_.getDouble(0)) + val responseResiduals = summary.residuals("response") + .select(col("responseResiduals")) + .collect() + .map(_.getDouble(0)) + + assert(model.coefficients ~== coefficientsR absTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + pearsonResiduals.zip(pearsonResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + workingResiduals.zip(workingResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + responseResiduals.zip(responseResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) + assert(summary.deviance ~== residualDevianceR absTol 1E-3) + assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) + assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) + assert(summary.aic ~== aicR absTol 1E-3) + } + + test("glm summary: binomial family with weight") { + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 2, 1, 3), 4, 2) + b <- c(1, 0, 1, 0) + w <- c(1, 2, 3, 4) + df <- as.data.frame(cbind(A, b)) + */ + val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), + Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) + ), 2)) + /* + R code: + + model <- glm(formula = "b ~ . -1", family="binomial", data = df, weights = w) + summary(model) + + Deviance Residuals: + 1 2 3 4 + 1.273 -1.437 2.533 -1.556 + + Coefficients: + Estimate Std. Error z value Pr(>|z|) + V1 -0.30217 0.46242 -0.653 0.513 + V2 -0.04452 0.37124 -0.120 0.905 + + (Dispersion parameter for binomial family taken to be 1) + + Null deviance: 13.863 on 4 degrees of freedom + Residual deviance: 12.524 on 2 degrees of freedom + AIC: 16.524 + + Number of Fisher Scoring iterations: 5 + + residuals(model, type="pearson") + 1 2 3 4 + 1.117731 -1.162962 2.395838 -1.189005 + + residuals(model, type="working") + 1 2 3 4 + 2.249324 -1.676240 2.913346 -1.353433 + + residuals(model, type="response") + 1 2 3 4 + 0.5554219 -0.4034267 0.6567520 -0.2611382 + */ + val trainer = new GeneralizedLinearRegression() + .setFamily("binomial") + .setWeightCol("weight") + .setFitIntercept(false) + + val model = trainer.fit(datasetWithWeight) + + val coefficientsR = Vectors.dense(Array(-0.30217, -0.04452)) + val interceptR = 0.0 + val devianceResidualsR = Array(1.273, -1.437, 2.533, -1.556) + val pearsonResidualsR = Array(1.117731, -1.162962, 2.395838, -1.189005) + val workingResidualsR = Array(2.249324, -1.676240, 2.913346, -1.353433) + val responseResidualsR = Array(0.5554219, -0.4034267, 0.6567520, -0.2611382) + val seCoefR = Array(0.46242, 0.37124) + val tValsR = Array(-0.653, -0.120) + val pValsR = Array(0.513, 0.905) + val dispersionR = 1.0 + val nullDevianceR = 13.863 + val residualDevianceR = 12.524 + val residualDegreeOfFreedomNullR = 4 + val residualDegreeOfFreedomR = 2 + val aicR = 16.524 + + val summary = model.summary + val devianceResiduals = summary.residuals() + .select(col("devianceResiduals")) + .collect() + .map(_.getDouble(0)) + val pearsonResiduals = summary.residuals("pearson") + .select(col("pearsonResiduals")) + .collect() + .map(_.getDouble(0)) + val workingResiduals = summary.residuals("working") + .select(col("workingResiduals")) + .collect() + .map(_.getDouble(0)) + val responseResiduals = summary.residuals("response") + .select(col("responseResiduals")) + .collect() + .map(_.getDouble(0)) + + assert(model.coefficients ~== coefficientsR absTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + pearsonResiduals.zip(pearsonResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + workingResiduals.zip(workingResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + responseResiduals.zip(responseResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) + assert(summary.deviance ~== residualDevianceR absTol 1E-3) + assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) + assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) + assert(summary.aic ~== aicR absTol 1E-3) + } + + test("glm summary: poisson family with weight") { + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(2, 8, 3, 9) + w <- c(1, 2, 3, 4) + df <- as.data.frame(cbind(A, b)) + */ + val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) + /* + R code: + + model <- glm(formula = "b ~ .", family="poisson", data = df, weights = w) + summary(model) + + Deviance Residuals: + 1 2 3 4 + -0.28952 0.11048 0.14839 -0.07268 + + Coefficients: + Estimate Std. Error z value Pr(>|z|) + (Intercept) 6.2999 1.6086 3.916 8.99e-05 *** + V1 3.3241 1.0184 3.264 0.00110 ** + V2 -1.0818 0.3522 -3.071 0.00213 ** + --- + Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 + + (Dispersion parameter for poisson family taken to be 1) + + Null deviance: 15.38066 on 3 degrees of freedom + Residual deviance: 0.12333 on 1 degrees of freedom + AIC: 41.803 + + Number of Fisher Scoring iterations: 3 + + residuals(model, type="pearson") + 1 2 3 4 + -0.28043145 0.11099310 0.14963714 -0.07253611 + + residuals(model, type="working") + 1 2 3 4 + -0.17960679 0.02813593 0.05113852 -0.01201650 + + residuals(model, type="response") + 1 2 3 4 + -0.4378554 0.2189277 0.1459518 -0.1094638 + */ + val trainer = new GeneralizedLinearRegression() + .setFamily("poisson") + .setWeightCol("weight") + .setFitIntercept(true) + + val model = trainer.fit(datasetWithWeight) + + val coefficientsR = Vectors.dense(Array(3.3241, -1.0818)) + val interceptR = 6.2999 + val devianceResidualsR = Array(-0.28952, 0.11048, 0.14839, -0.07268) + val pearsonResidualsR = Array(-0.28043145, 0.11099310, 0.14963714, -0.07253611) + val workingResidualsR = Array(-0.17960679, 0.02813593, 0.05113852, -0.01201650) + val responseResidualsR = Array(-0.4378554, 0.2189277, 0.1459518, -0.1094638) + val seCoefR = Array(1.0184, 0.3522, 1.6086) + val tValsR = Array(3.264, -3.071, 3.916) + val pValsR = Array(0.00110, 0.00213, 0.00009) + val dispersionR = 1.0 + val nullDevianceR = 15.38066 + val residualDevianceR = 0.12333 + val residualDegreeOfFreedomNullR = 3 + val residualDegreeOfFreedomR = 1 + val aicR = 41.803 + + val summary = model.summary + val devianceResiduals = summary.residuals() + .select(col("devianceResiduals")) + .collect() + .map(_.getDouble(0)) + val pearsonResiduals = summary.residuals("pearson") + .select(col("pearsonResiduals")) + .collect() + .map(_.getDouble(0)) + val workingResiduals = summary.residuals("working") + .select(col("workingResiduals")) + .collect() + .map(_.getDouble(0)) + val responseResiduals = summary.residuals("response") + .select(col("responseResiduals")) + .collect() + .map(_.getDouble(0)) + + assert(model.coefficients ~== coefficientsR absTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + pearsonResiduals.zip(pearsonResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + workingResiduals.zip(workingResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + responseResiduals.zip(responseResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) + assert(summary.deviance ~== residualDevianceR absTol 1E-3) + assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) + assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) + assert(summary.aic ~== aicR absTol 1E-3) + } + + test("glm summary: gamma family with weight") { + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(2, 8, 3, 9) + w <- c(1, 2, 3, 4) + df <- as.data.frame(cbind(A, b)) + */ + val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) + /* + R code: + + model <- glm(formula = "b ~ .", family="Gamma", data = df, weights = w) + summary(model) + + Deviance Residuals: + 1 2 3 4 + -0.26343 0.05761 0.12818 -0.03484 + + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) -0.81511 0.23449 -3.476 0.178 + V1 -0.72730 0.16137 -4.507 0.139 + V2 0.23894 0.05481 4.359 0.144 + + (Dispersion parameter for Gamma family taken to be 0.07986091) + + Null deviance: 2.937462 on 3 degrees of freedom + Residual deviance: 0.090358 on 1 degrees of freedom + AIC: 23.202 + + Number of Fisher Scoring iterations: 4 + + residuals(model, type="pearson") + 1 2 3 4 + -0.24082508 0.05839241 0.13135766 -0.03463621 + + residuals(model, type="working") + 1 2 3 4 + 0.091414181 -0.005374314 -0.027196998 0.001890910 + + residuals(model, type="response") + 1 2 3 4 + -0.6344390 0.3172195 0.2114797 -0.1586097 + */ + val trainer = new GeneralizedLinearRegression() + .setFamily("gamma") + .setWeightCol("weight") + + val model = trainer.fit(datasetWithWeight) + + val coefficientsR = Vectors.dense(Array(-0.72730, 0.23894)) + val interceptR = -0.81511 + val devianceResidualsR = Array(-0.26343, 0.05761, 0.12818, -0.03484) + val pearsonResidualsR = Array(-0.24082508, 0.05839241, 0.13135766, -0.03463621) + val workingResidualsR = Array(0.091414181, -0.005374314, -0.027196998, 0.001890910) + val responseResidualsR = Array(-0.6344390, 0.3172195, 0.2114797, -0.1586097) + val seCoefR = Array(0.16137, 0.05481, 0.23449) + val tValsR = Array(-4.507, 4.359, -3.476) + val pValsR = Array(0.139, 0.144, 0.178) + val dispersionR = 0.07986091 + val nullDevianceR = 2.937462 + val residualDevianceR = 0.090358 + val residualDegreeOfFreedomNullR = 3 + val residualDegreeOfFreedomR = 1 + val aicR = 23.202 + + val summary = model.summary + val devianceResiduals = summary.residuals() + .select(col("devianceResiduals")) + .collect() + .map(_.getDouble(0)) + val pearsonResiduals = summary.residuals("pearson") + .select(col("pearsonResiduals")) + .collect() + .map(_.getDouble(0)) + val workingResiduals = summary.residuals("working") + .select(col("workingResiduals")) + .collect() + .map(_.getDouble(0)) + val responseResiduals = summary.residuals("response") + .select(col("responseResiduals")) + .collect() + .map(_.getDouble(0)) + + assert(model.coefficients ~== coefficientsR absTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + pearsonResiduals.zip(pearsonResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + workingResiduals.zip(workingResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + responseResiduals.zip(responseResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) + assert(summary.deviance ~== residualDevianceR absTol 1E-3) + assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) + assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) + assert(summary.aic ~== aicR absTol 1E-3) + } + test("read/write") { def checkModelData( model: GeneralizedLinearRegressionModel, -- cgit v1.2.3