aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-03-15 22:30:07 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-15 22:30:07 -0700
commit3665294d4e1c6ea13ee66e71cf802f1a961ab15c (patch)
treea6e10ebf8cd71a1be6055ffa631f466395de2efb /mllib
parent421f6c20e85b32f6462d37dad6a62dec2d46ed88 (diff)
downloadspark-3665294d4e1c6ea13ee66e71cf802f1a961ab15c.tar.gz
spark-3665294d4e1c6ea13ee66e71cf802f1a961ab15c.tar.bz2
spark-3665294d4e1c6ea13ee66e71cf802f1a961ab15c.zip
[SPARK-9837][ML] R-like summary statistics for GLMs via iteratively reweighted least squares
## What changes were proposed in this pull request? Provide R-like summary statistics for GLMs via iteratively reweighted least squares. ## How was this patch tested? unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #11694 from yanboliang/spark-9837.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala9
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala341
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala457
3 files changed, 796 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala
index 6aa44e6ba7..fe82324ab2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala
@@ -26,10 +26,14 @@ import org.apache.spark.rdd.RDD
* Model fitted by [[IterativelyReweightedLeastSquares]].
* @param coefficients model coefficients
* @param intercept model intercept
+ * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
+ * @param numIterations number of iterations
*/
private[ml] class IterativelyReweightedLeastSquaresModel(
val coefficients: DenseVector,
- val intercept: Double) extends Serializable
+ val intercept: Double,
+ val diagInvAtWA: DenseVector,
+ val numIterations: Int) extends Serializable
/**
* Implements the method of iteratively reweighted least squares (IRLS) which is used to solve
@@ -103,6 +107,7 @@ private[ml] class IterativelyReweightedLeastSquares(
}
- new IterativelyReweightedLeastSquaresModel(model.coefficients, model.intercept)
+ new IterativelyReweightedLeastSquaresModel(
+ model.coefficients, model.intercept, model.diagInvAtWA, iter)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index de1dff9421..b4e47c8073 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.regression
-import breeze.stats.distributions.{Gaussian => GD}
+import breeze.stats.{distributions => dist}
import org.apache.hadoop.fs.Path
import org.apache.spark.{Logging, SparkException}
@@ -217,7 +217,15 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
val model = copyValues(
new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept)
.setParent(this))
- return model
+ // Handle possible missing or invalid prediction columns
+ val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
+ val trainingSummary = new GeneralizedLinearRegressionSummary(
+ summaryModel.transform(dataset),
+ predictionColName,
+ model,
+ wlsModel.diagInvAtWA.toArray,
+ 1)
+ return model.setSummary(trainingSummary)
}
// Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
@@ -229,7 +237,16 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
val model = copyValues(
new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept)
.setParent(this))
- model
+ // Handle possible missing or invalid prediction columns
+ val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
+ val trainingSummary = new GeneralizedLinearRegressionSummary(
+ summaryModel.transform(dataset),
+ predictionColName,
+ model,
+ irlsModel.diagInvAtWA.toArray,
+ irlsModel.numIterations)
+
+ model.setSummary(trainingSummary)
}
@Since("2.0.0")
@@ -318,6 +335,22 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
/** The variance of the endogenous variable's mean, given the value mu. */
def variance(mu: Double): Double
+ /** Deviance of (y, mu) pair. */
+ def deviance(y: Double, mu: Double, weight: Double): Double
+
+ /**
+ * Akaike's 'An Information Criterion'(AIC) value of the family for a given dataset.
+ * @param predictions an RDD of (y, mu, weight) of instances in evaluation dataset
+ * @param deviance the deviance for the fitted model in evaluation dataset
+ * @param numInstances number of instances in evaluation dataset
+ * @param weightSum weights sum of instances in evaluation dataset
+ */
+ def aic(
+ predictions: RDD[(Double, Double, Double)],
+ deviance: Double,
+ numInstances: Double,
+ weightSum: Double): Double
+
/** Trim the fitted value so that it will be in valid range. */
def project(mu: Double): Double = mu
}
@@ -348,7 +381,20 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def initialize(y: Double, weight: Double): Double = y
- def variance(mu: Double): Double = 1.0
+ override def variance(mu: Double): Double = 1.0
+
+ override def deviance(y: Double, mu: Double, weight: Double): Double = {
+ weight * (y - mu) * (y - mu)
+ }
+
+ override def aic(
+ predictions: RDD[(Double, Double, Double)],
+ deviance: Double,
+ numInstances: Double,
+ weightSum: Double): Double = {
+ val wt = predictions.map(x => math.log(x._3)).sum()
+ numInstances * (math.log(deviance / numInstances * 2.0 * math.Pi) + 1.0) + 2.0 - wt
+ }
override def project(mu: Double): Double = {
if (mu.isNegInfinity) {
@@ -378,6 +424,22 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def variance(mu: Double): Double = mu * (1.0 - mu)
+ override def deviance(y: Double, mu: Double, weight: Double): Double = {
+ val my = 1.0 - y
+ 2.0 * weight * (y * math.log(math.max(y, 1.0) / mu) +
+ my * math.log(math.max(my, 1.0) / (1.0 - mu)))
+ }
+
+ override def aic(
+ predictions: RDD[(Double, Double, Double)],
+ deviance: Double,
+ numInstances: Double,
+ weightSum: Double): Double = {
+ -2.0 * predictions.map { case (y: Double, mu: Double, weight: Double) =>
+ weight * dist.Binomial(1, mu).logProbabilityOf(math.round(y).toInt)
+ }.sum()
+ }
+
override def project(mu: Double): Double = {
if (mu < epsilon) {
epsilon
@@ -405,6 +467,20 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def variance(mu: Double): Double = mu
+ override def deviance(y: Double, mu: Double, weight: Double): Double = {
+ 2.0 * weight * (y * math.log(y / mu) - (y - mu))
+ }
+
+ override def aic(
+ predictions: RDD[(Double, Double, Double)],
+ deviance: Double,
+ numInstances: Double,
+ weightSum: Double): Double = {
+ -2.0 * predictions.map { case (y: Double, mu: Double, weight: Double) =>
+ weight * dist.Poisson(mu).logProbabilityOf(y.toInt)
+ }.sum()
+ }
+
override def project(mu: Double): Double = {
if (mu < epsilon) {
epsilon
@@ -430,7 +506,22 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
y
}
- override def variance(mu: Double): Double = math.pow(mu, 2.0)
+ override def variance(mu: Double): Double = mu * mu
+
+ override def deviance(y: Double, mu: Double, weight: Double): Double = {
+ -2.0 * weight * (math.log(y / mu) - (y - mu)/mu)
+ }
+
+ override def aic(
+ predictions: RDD[(Double, Double, Double)],
+ deviance: Double,
+ numInstances: Double,
+ weightSum: Double): Double = {
+ val disp = deviance / weightSum
+ -2.0 * predictions.map { case (y: Double, mu: Double, weight: Double) =>
+ weight * dist.Gamma(1.0 / disp, mu * disp).logPdf(y)
+ }.sum() + 2.0
+ }
override def project(mu: Double): Double = {
if (mu < epsilon) {
@@ -519,11 +610,13 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
private[ml] object Probit extends Link("probit") {
- override def link(mu: Double): Double = GD(0.0, 1.0).icdf(mu)
+ override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu)
- override def deriv(mu: Double): Double = 1.0 / GD(0.0, 1.0).pdf(GD(0.0, 1.0).icdf(mu))
+ override def deriv(mu: Double): Double = {
+ 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).icdf(mu))
+ }
- override def unlink(eta: Double): Double = GD(0.0, 1.0).cdf(eta)
+ override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta)
}
private[ml] object CLogLog extends Link("cloglog") {
@@ -541,7 +634,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def deriv(mu: Double): Double = 1.0 / (2.0 * math.sqrt(mu))
- override def unlink(eta: Double): Double = math.pow(eta, 2.0)
+ override def unlink(eta: Double): Double = eta * eta
}
}
@@ -573,6 +666,39 @@ class GeneralizedLinearRegressionModel private[ml] (
familyAndLink.fitted(eta)
}
+ private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None
+
+ /**
+ * Gets R-like summary of model on training set. An exception is
+ * thrown if `trainingSummary == None`.
+ */
+ @Since("2.0.0")
+ def summary: GeneralizedLinearRegressionSummary = trainingSummary.getOrElse {
+ throw new SparkException(
+ "No training summary available for this GeneralizedLinearRegressionModel",
+ new RuntimeException())
+ }
+
+ private[regression] def setSummary(summary: GeneralizedLinearRegressionSummary): this.type = {
+ this.trainingSummary = Some(summary)
+ this
+ }
+
+ /**
+ * If the prediction column is set returns the current model and prediction column,
+ * otherwise generates a new column and sets it as the prediction column on a new copy
+ * of the current model.
+ */
+ private[regression] def findSummaryModelAndPredictionCol()
+ : (GeneralizedLinearRegressionModel, String) = {
+ $(predictionCol) match {
+ case "" =>
+ val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+ (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
+ case p => (this, p)
+ }
+ }
+
@Since("2.0.0")
override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
@@ -633,3 +759,200 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
}
}
}
+
+/**
+ * :: Experimental ::
+ * Summarizing Generalized Linear regression Fits.
+ *
+ * @param predictions predictions outputted by the model's `transform` method
+ * @param predictionCol field in "predictions" which gives the prediction value of each instance
+ * @param model the model that should be summarized
+ * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
+ * @param numIterations number of iterations
+ */
+@Since("2.0.0")
+@Experimental
+class GeneralizedLinearRegressionSummary private[regression] (
+ @Since("2.0.0") @transient val predictions: DataFrame,
+ @Since("2.0.0") val predictionCol: String,
+ @Since("2.0.0") val model: GeneralizedLinearRegressionModel,
+ private val diagInvAtWA: Array[Double],
+ @Since("2.0.0") val numIterations: Int) extends Serializable {
+
+ import GeneralizedLinearRegression._
+
+ private lazy val family = Family.fromName(model.getFamily)
+ private lazy val link = if (model.isDefined(model.getParam("link"))) {
+ Link.fromName(model.getLink)
+ } else {
+ family.defaultLink
+ }
+
+ /** Number of instances in DataFrame predictions */
+ private lazy val numInstances: Long = predictions.count()
+
+ /** The numeric rank of the fitted linear model */
+ @Since("2.0.0")
+ lazy val rank: Long = if (model.getFitIntercept) {
+ model.coefficients.size + 1
+ } else {
+ model.coefficients.size
+ }
+
+ /** Degrees of freedom */
+ @Since("2.0.0")
+ lazy val degreesOfFreedom: Long = {
+ numInstances - rank
+ }
+
+ /** The residual degrees of freedom */
+ @Since("2.0.0")
+ lazy val residualDegreeOfFreedom: Long = degreesOfFreedom
+
+ /** The residual degrees of freedom for the null model */
+ @Since("2.0.0")
+ lazy val residualDegreeOfFreedomNull: Long = if (model.getFitIntercept) {
+ numInstances - 1
+ } else {
+ numInstances
+ }
+
+ private lazy val devianceResiduals: DataFrame = {
+ val drUDF = udf { (y: Double, mu: Double, weight: Double) =>
+ val r = math.sqrt(math.max(family.deviance(y, mu, weight), 0.0))
+ if (y > mu) r else -1.0 * r
+ }
+ val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol)
+ predictions.select(
+ drUDF(col(model.getLabelCol), col(predictionCol), w).as("devianceResiduals"))
+ }
+
+ private lazy val pearsonResiduals: DataFrame = {
+ val prUDF = udf { mu: Double => family.variance(mu) }
+ val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol)
+ predictions.select(col(model.getLabelCol).minus(col(predictionCol))
+ .multiply(sqrt(w)).divide(sqrt(prUDF(col(predictionCol)))).as("pearsonResiduals"))
+ }
+
+ private lazy val workingResiduals: DataFrame = {
+ val wrUDF = udf { (y: Double, mu: Double) => (y - mu) * link.deriv(mu) }
+ predictions.select(wrUDF(col(model.getLabelCol), col(predictionCol)).as("workingResiduals"))
+ }
+
+ private lazy val responseResiduals: DataFrame = {
+ predictions.select(col(model.getLabelCol).minus(col(predictionCol)).as("responseResiduals"))
+ }
+
+ /**
+ * Get the default residuals(deviance residuals) of the fitted model.
+ */
+ @Since("2.0.0")
+ def residuals(): DataFrame = devianceResiduals
+
+ /**
+ * Get the residuals of the fitted model by type.
+ * @param residualsType The type of residuals which should be returned.
+ * Supported options: deviance, pearson, working and response.
+ */
+ @Since("2.0.0")
+ def residuals(residualsType: String): DataFrame = {
+ residualsType match {
+ case "deviance" => devianceResiduals
+ case "pearson" => pearsonResiduals
+ case "working" => workingResiduals
+ case "response" => responseResiduals
+ case other => throw new UnsupportedOperationException(
+ s"The residuals type $other is not supported by Generalized Linear Regression.")
+ }
+ }
+
+ /**
+ * The deviance for the null model.
+ */
+ @Since("2.0.0")
+ lazy val nullDeviance: Double = {
+ val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol)
+ val wtdmu: Double = if (model.getFitIntercept) {
+ val agg = predictions.agg(sum(w.multiply(col(model.getLabelCol))), sum(w)).first()
+ agg.getDouble(0) / agg.getDouble(1)
+ } else {
+ link.unlink(0.0)
+ }
+ predictions.select(col(model.getLabelCol), w).rdd.map {
+ case Row(y: Double, weight: Double) =>
+ family.deviance(y, wtdmu, weight)
+ }.sum()
+ }
+
+ /**
+ * The deviance for the fitted model.
+ */
+ @Since("2.0.0")
+ lazy val deviance: Double = {
+ val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol)
+ predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
+ case Row(label: Double, pred: Double, weight: Double) =>
+ family.deviance(label, pred, weight)
+ }.sum()
+ }
+
+ /**
+ * The dispersion of the fitted model.
+ * It is taken as 1.0 for the "binomial" and "poisson" families, and otherwise
+ * estimated by the residual Pearson's Chi-Squared statistic(which is defined as
+ * sum of the squares of the Pearson residuals) divided by the residual degrees of freedom.
+ */
+ @Since("2.0.0")
+ lazy val dispersion: Double = if (
+ model.getFamily == Binomial.name || model.getFamily == Poisson.name) {
+ 1.0
+ } else {
+ val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0)
+ rss / degreesOfFreedom
+ }
+
+ /** Akaike's "An Information Criterion"(AIC) for the fitted model. */
+ @Since("2.0.0")
+ lazy val aic: Double = {
+ val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol)
+ val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0)
+ val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
+ case Row(label: Double, pred: Double, weight: Double) =>
+ (label, pred, weight)
+ }
+ family.aic(t, deviance, numInstances, weightSum) + 2 * rank
+ }
+
+ /**
+ * Standard error of estimated coefficients and intercept.
+ */
+ @Since("2.0.0")
+ lazy val coefficientStandardErrors: Array[Double] = {
+ diagInvAtWA.map(_ * dispersion).map(math.sqrt)
+ }
+
+ /**
+ * T-statistic of estimated coefficients and intercept.
+ */
+ @Since("2.0.0")
+ lazy val tValues: Array[Double] = {
+ val estimate = if (model.getFitIntercept) {
+ Array.concat(model.coefficients.toArray, Array(model.intercept))
+ } else {
+ model.coefficients.toArray
+ }
+ estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 }
+ }
+
+ /**
+ * Two-sided p-value of estimated coefficients and intercept.
+ */
+ @Since("2.0.0")
+ lazy val pValues: Array[Double] = {
+ if (model.getFamily == Binomial.name || model.getFamily == Poisson.name) {
+ tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) }
+ } else {
+ tValues.map { x => 2.0 * (1.0 - dist.StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x))) }
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 618304ad19..6d570f7bde 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.regression
import scala.util.Random
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
@@ -29,6 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
class GeneralizedLinearRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -466,6 +468,461 @@ class GeneralizedLinearRegressionSuite
}
}
+ test("glm summary: gaussian family with weight") {
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+ b <- c(17, 19, 23, 29)
+ w <- c(1, 2, 3, 4)
+ df <- as.data.frame(cbind(A, b))
+ */
+ val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+ Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2))
+ /*
+ R code:
+
+ model <- glm(formula = "b ~ .", family="gaussian", data = df, weights = w)
+ summary(model)
+
+ Deviance Residuals:
+ 1 2 3 4
+ 1.920 -1.358 -1.109 0.960
+
+ Coefficients:
+ Estimate Std. Error t value Pr(>|t|)
+ (Intercept) 18.080 9.608 1.882 0.311
+ V1 6.080 5.556 1.094 0.471
+ V2 -0.600 1.960 -0.306 0.811
+
+ (Dispersion parameter for gaussian family taken to be 7.68)
+
+ Null deviance: 202.00 on 3 degrees of freedom
+ Residual deviance: 7.68 on 1 degrees of freedom
+ AIC: 18.783
+
+ Number of Fisher Scoring iterations: 2
+
+ residuals(model, type="pearson")
+ 1 2 3 4
+ 1.920000 -1.357645 -1.108513 0.960000
+
+ residuals(model, type="working")
+ 1 2 3 4
+ 1.92 -0.96 -0.64 0.48
+
+ residuals(model, type="response")
+ 1 2 3 4
+ 1.92 -0.96 -0.64 0.48
+ */
+ val trainer = new GeneralizedLinearRegression()
+ .setWeightCol("weight")
+
+ val model = trainer.fit(datasetWithWeight)
+
+ val coefficientsR = Vectors.dense(Array(6.080, -0.600))
+ val interceptR = 18.080
+ val devianceResidualsR = Array(1.920, -1.358, -1.109, 0.960)
+ val pearsonResidualsR = Array(1.920000, -1.357645, -1.108513, 0.960000)
+ val workingResidualsR = Array(1.92, -0.96, -0.64, 0.48)
+ val responseResidualsR = Array(1.92, -0.96, -0.64, 0.48)
+ val seCoefR = Array(5.556, 1.960, 9.608)
+ val tValsR = Array(1.094, -0.306, 1.882)
+ val pValsR = Array(0.471, 0.811, 0.311)
+ val dispersionR = 7.68
+ val nullDevianceR = 202.00
+ val residualDevianceR = 7.68
+ val residualDegreeOfFreedomNullR = 3
+ val residualDegreeOfFreedomR = 1
+ val aicR = 18.783
+
+ val summary = model.summary
+
+ val devianceResiduals = summary.residuals()
+ .select(col("devianceResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+ val pearsonResiduals = summary.residuals("pearson")
+ .select(col("pearsonResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+ val workingResiduals = summary.residuals("working")
+ .select(col("workingResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+ val responseResiduals = summary.residuals("response")
+ .select(col("responseResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+
+ assert(model.coefficients ~== coefficientsR absTol 1E-3)
+ assert(model.intercept ~== interceptR absTol 1E-3)
+ devianceResiduals.zip(devianceResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ pearsonResiduals.zip(pearsonResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ workingResiduals.zip(workingResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ responseResiduals.zip(responseResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ summary.coefficientStandardErrors.zip(seCoefR).foreach{ x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+ summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+ assert(summary.dispersion ~== dispersionR absTol 1E-3)
+ assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3)
+ assert(summary.deviance ~== residualDevianceR absTol 1E-3)
+ assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
+ assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
+ assert(summary.aic ~== aicR absTol 1E-3)
+ }
+
+ test("glm summary: binomial family with weight") {
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 2, 1, 3), 4, 2)
+ b <- c(1, 0, 1, 0)
+ w <- c(1, 2, 3, 4)
+ df <- as.data.frame(cbind(A, b))
+ */
+ val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+ Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)),
+ Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)),
+ Instance(0.0, 4.0, Vectors.dense(3.0, 3.0))
+ ), 2))
+ /*
+ R code:
+
+ model <- glm(formula = "b ~ . -1", family="binomial", data = df, weights = w)
+ summary(model)
+
+ Deviance Residuals:
+ 1 2 3 4
+ 1.273 -1.437 2.533 -1.556
+
+ Coefficients:
+ Estimate Std. Error z value Pr(>|z|)
+ V1 -0.30217 0.46242 -0.653 0.513
+ V2 -0.04452 0.37124 -0.120 0.905
+
+ (Dispersion parameter for binomial family taken to be 1)
+
+ Null deviance: 13.863 on 4 degrees of freedom
+ Residual deviance: 12.524 on 2 degrees of freedom
+ AIC: 16.524
+
+ Number of Fisher Scoring iterations: 5
+
+ residuals(model, type="pearson")
+ 1 2 3 4
+ 1.117731 -1.162962 2.395838 -1.189005
+
+ residuals(model, type="working")
+ 1 2 3 4
+ 2.249324 -1.676240 2.913346 -1.353433
+
+ residuals(model, type="response")
+ 1 2 3 4
+ 0.5554219 -0.4034267 0.6567520 -0.2611382
+ */
+ val trainer = new GeneralizedLinearRegression()
+ .setFamily("binomial")
+ .setWeightCol("weight")
+ .setFitIntercept(false)
+
+ val model = trainer.fit(datasetWithWeight)
+
+ val coefficientsR = Vectors.dense(Array(-0.30217, -0.04452))
+ val interceptR = 0.0
+ val devianceResidualsR = Array(1.273, -1.437, 2.533, -1.556)
+ val pearsonResidualsR = Array(1.117731, -1.162962, 2.395838, -1.189005)
+ val workingResidualsR = Array(2.249324, -1.676240, 2.913346, -1.353433)
+ val responseResidualsR = Array(0.5554219, -0.4034267, 0.6567520, -0.2611382)
+ val seCoefR = Array(0.46242, 0.37124)
+ val tValsR = Array(-0.653, -0.120)
+ val pValsR = Array(0.513, 0.905)
+ val dispersionR = 1.0
+ val nullDevianceR = 13.863
+ val residualDevianceR = 12.524
+ val residualDegreeOfFreedomNullR = 4
+ val residualDegreeOfFreedomR = 2
+ val aicR = 16.524
+
+ val summary = model.summary
+ val devianceResiduals = summary.residuals()
+ .select(col("devianceResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+ val pearsonResiduals = summary.residuals("pearson")
+ .select(col("pearsonResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+ val workingResiduals = summary.residuals("working")
+ .select(col("workingResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+ val responseResiduals = summary.residuals("response")
+ .select(col("responseResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+
+ assert(model.coefficients ~== coefficientsR absTol 1E-3)
+ assert(model.intercept ~== interceptR absTol 1E-3)
+ devianceResiduals.zip(devianceResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ pearsonResiduals.zip(pearsonResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ workingResiduals.zip(workingResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ responseResiduals.zip(responseResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ summary.coefficientStandardErrors.zip(seCoefR).foreach{ x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+ summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+ assert(summary.dispersion ~== dispersionR absTol 1E-3)
+ assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3)
+ assert(summary.deviance ~== residualDevianceR absTol 1E-3)
+ assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
+ assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
+ assert(summary.aic ~== aicR absTol 1E-3)
+ }
+
+ test("glm summary: poisson family with weight") {
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+ b <- c(2, 8, 3, 9)
+ w <- c(1, 2, 3, 4)
+ df <- as.data.frame(cbind(A, b))
+ */
+ val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+ Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(9.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2))
+ /*
+ R code:
+
+ model <- glm(formula = "b ~ .", family="poisson", data = df, weights = w)
+ summary(model)
+
+ Deviance Residuals:
+ 1 2 3 4
+ -0.28952 0.11048 0.14839 -0.07268
+
+ Coefficients:
+ Estimate Std. Error z value Pr(>|z|)
+ (Intercept) 6.2999 1.6086 3.916 8.99e-05 ***
+ V1 3.3241 1.0184 3.264 0.00110 **
+ V2 -1.0818 0.3522 -3.071 0.00213 **
+ ---
+ Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
+
+ (Dispersion parameter for poisson family taken to be 1)
+
+ Null deviance: 15.38066 on 3 degrees of freedom
+ Residual deviance: 0.12333 on 1 degrees of freedom
+ AIC: 41.803
+
+ Number of Fisher Scoring iterations: 3
+
+ residuals(model, type="pearson")
+ 1 2 3 4
+ -0.28043145 0.11099310 0.14963714 -0.07253611
+
+ residuals(model, type="working")
+ 1 2 3 4
+ -0.17960679 0.02813593 0.05113852 -0.01201650
+
+ residuals(model, type="response")
+ 1 2 3 4
+ -0.4378554 0.2189277 0.1459518 -0.1094638
+ */
+ val trainer = new GeneralizedLinearRegression()
+ .setFamily("poisson")
+ .setWeightCol("weight")
+ .setFitIntercept(true)
+
+ val model = trainer.fit(datasetWithWeight)
+
+ val coefficientsR = Vectors.dense(Array(3.3241, -1.0818))
+ val interceptR = 6.2999
+ val devianceResidualsR = Array(-0.28952, 0.11048, 0.14839, -0.07268)
+ val pearsonResidualsR = Array(-0.28043145, 0.11099310, 0.14963714, -0.07253611)
+ val workingResidualsR = Array(-0.17960679, 0.02813593, 0.05113852, -0.01201650)
+ val responseResidualsR = Array(-0.4378554, 0.2189277, 0.1459518, -0.1094638)
+ val seCoefR = Array(1.0184, 0.3522, 1.6086)
+ val tValsR = Array(3.264, -3.071, 3.916)
+ val pValsR = Array(0.00110, 0.00213, 0.00009)
+ val dispersionR = 1.0
+ val nullDevianceR = 15.38066
+ val residualDevianceR = 0.12333
+ val residualDegreeOfFreedomNullR = 3
+ val residualDegreeOfFreedomR = 1
+ val aicR = 41.803
+
+ val summary = model.summary
+ val devianceResiduals = summary.residuals()
+ .select(col("devianceResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+ val pearsonResiduals = summary.residuals("pearson")
+ .select(col("pearsonResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+ val workingResiduals = summary.residuals("working")
+ .select(col("workingResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+ val responseResiduals = summary.residuals("response")
+ .select(col("responseResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+
+ assert(model.coefficients ~== coefficientsR absTol 1E-3)
+ assert(model.intercept ~== interceptR absTol 1E-3)
+ devianceResiduals.zip(devianceResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ pearsonResiduals.zip(pearsonResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ workingResiduals.zip(workingResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ responseResiduals.zip(responseResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ summary.coefficientStandardErrors.zip(seCoefR).foreach{ x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+ summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+ assert(summary.dispersion ~== dispersionR absTol 1E-3)
+ assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3)
+ assert(summary.deviance ~== residualDevianceR absTol 1E-3)
+ assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
+ assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
+ assert(summary.aic ~== aicR absTol 1E-3)
+ }
+
+ test("glm summary: gamma family with weight") {
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+ b <- c(2, 8, 3, 9)
+ w <- c(1, 2, 3, 4)
+ df <- as.data.frame(cbind(A, b))
+ */
+ val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+ Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(9.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2))
+ /*
+ R code:
+
+ model <- glm(formula = "b ~ .", family="Gamma", data = df, weights = w)
+ summary(model)
+
+ Deviance Residuals:
+ 1 2 3 4
+ -0.26343 0.05761 0.12818 -0.03484
+
+ Coefficients:
+ Estimate Std. Error t value Pr(>|t|)
+ (Intercept) -0.81511 0.23449 -3.476 0.178
+ V1 -0.72730 0.16137 -4.507 0.139
+ V2 0.23894 0.05481 4.359 0.144
+
+ (Dispersion parameter for Gamma family taken to be 0.07986091)
+
+ Null deviance: 2.937462 on 3 degrees of freedom
+ Residual deviance: 0.090358 on 1 degrees of freedom
+ AIC: 23.202
+
+ Number of Fisher Scoring iterations: 4
+
+ residuals(model, type="pearson")
+ 1 2 3 4
+ -0.24082508 0.05839241 0.13135766 -0.03463621
+
+ residuals(model, type="working")
+ 1 2 3 4
+ 0.091414181 -0.005374314 -0.027196998 0.001890910
+
+ residuals(model, type="response")
+ 1 2 3 4
+ -0.6344390 0.3172195 0.2114797 -0.1586097
+ */
+ val trainer = new GeneralizedLinearRegression()
+ .setFamily("gamma")
+ .setWeightCol("weight")
+
+ val model = trainer.fit(datasetWithWeight)
+
+ val coefficientsR = Vectors.dense(Array(-0.72730, 0.23894))
+ val interceptR = -0.81511
+ val devianceResidualsR = Array(-0.26343, 0.05761, 0.12818, -0.03484)
+ val pearsonResidualsR = Array(-0.24082508, 0.05839241, 0.13135766, -0.03463621)
+ val workingResidualsR = Array(0.091414181, -0.005374314, -0.027196998, 0.001890910)
+ val responseResidualsR = Array(-0.6344390, 0.3172195, 0.2114797, -0.1586097)
+ val seCoefR = Array(0.16137, 0.05481, 0.23449)
+ val tValsR = Array(-4.507, 4.359, -3.476)
+ val pValsR = Array(0.139, 0.144, 0.178)
+ val dispersionR = 0.07986091
+ val nullDevianceR = 2.937462
+ val residualDevianceR = 0.090358
+ val residualDegreeOfFreedomNullR = 3
+ val residualDegreeOfFreedomR = 1
+ val aicR = 23.202
+
+ val summary = model.summary
+ val devianceResiduals = summary.residuals()
+ .select(col("devianceResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+ val pearsonResiduals = summary.residuals("pearson")
+ .select(col("pearsonResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+ val workingResiduals = summary.residuals("working")
+ .select(col("workingResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+ val responseResiduals = summary.residuals("response")
+ .select(col("responseResiduals"))
+ .collect()
+ .map(_.getDouble(0))
+
+ assert(model.coefficients ~== coefficientsR absTol 1E-3)
+ assert(model.intercept ~== interceptR absTol 1E-3)
+ devianceResiduals.zip(devianceResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ pearsonResiduals.zip(pearsonResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ workingResiduals.zip(workingResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ responseResiduals.zip(responseResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ summary.coefficientStandardErrors.zip(seCoefR).foreach{ x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+ summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+ assert(summary.dispersion ~== dispersionR absTol 1E-3)
+ assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3)
+ assert(summary.deviance ~== residualDevianceR absTol 1E-3)
+ assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
+ assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
+ assert(summary.aic ~== aicR absTol 1E-3)
+ }
+
test("read/write") {
def checkModelData(
model: GeneralizedLinearRegressionModel,