From d6f10aa7ea2806c0fbcfc31d7dee91d28319fab7 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Tue, 3 Nov 2015 08:29:07 -0800
Subject: [SPARK-9836][ML] Provide R-like summary statistics for OLS via
 normal equation solver

https://issues.apache.org/jira/browse/SPARK-9836

Author: Yanbo Liang

Closes #9413 from yanboliang/spark-9836.
---
 .../spark/ml/optim/WeightedLeastSquares.scala      |  15 ++-
 .../spark/ml/regression/LinearRegression.scala     |  90 +++++++++++++-
 .../spark/mllib/linalg/CholeskyDecomposition.scala |  16 +++
 .../ml/regression/LinearRegressionSuite.scala      | 129 +++++++++++++++++++++
 4 files changed, 243 insertions(+), 7 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 3d64f7f296..e612a2122e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -26,10 +26,12 @@ import org.apache.spark.rdd.RDD
  * Model fitted by [[WeightedLeastSquares]].
  * @param coefficients model coefficients
  * @param intercept model intercept
+ * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1
  */
 private[ml] class WeightedLeastSquaresModel(
     val coefficients: DenseVector,
-    val intercept: Double) extends Serializable
+    val intercept: Double,
+    val diagInvAtWA: DenseVector) extends Serializable
 
 /**
  * Weighted least squares solver via normal equation.
@@ -73,7 +75,9 @@ private[ml] class WeightedLeastSquares(
     val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_))
     summary.validate()
     logInfo(s"Number of instances: ${summary.count}.")
+    val k = summary.k
     val triK = summary.triK
+    val wSum = summary.wSum
     val bBar = summary.bBar
     val bStd = summary.bStd
     val aBar = summary.aBar
@@ -109,6 +113,11 @@ private[ml] class WeightedLeastSquares(
 
     val x = new DenseVector(CholeskyDecomposition.solve(aaBar.values, abBar.values))
 
+    val aaInv = CholeskyDecomposition.inverse(aaBar.values, k)
+    // aaInv is a packed upper triangular matrix; here we get all elements on the diagonal
+    val diagInvAtWA = new DenseVector((1 to k).map { i =>
+      aaInv(i + (i - 1) * i / 2 - 1) / wSum }.toArray)
+
     // compute intercept
     val intercept = if (fitIntercept) {
       bBar - BLAS.dot(aBar, x)
@@ -116,7 +125,7 @@ private[ml] class WeightedLeastSquares(
       0.0
     }
 
-    new WeightedLeastSquaresModel(x, intercept)
+    new WeightedLeastSquaresModel(x, intercept, diagInvAtWA)
   }
 }
 
@@ -131,7 +140,7 @@ private[ml] object WeightedLeastSquares {
     var k: Int = _
     var count: Long = _
     var triK: Int = _
-    private var wSum: Double = _
+    var wSum: Double = _
     private var wwSum: Double = _
     private var bSum: Double = _
     private var bbSum: Double = _
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 6e9c7442b8..c51e30483a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
 
 import breeze.linalg.{DenseVector => BDV}
 import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import breeze.stats.distributions.StudentsT
 
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.annotation.Experimental
@@ -36,7 +37,7 @@ import org.apache.spark.mllib.linalg.BLAS._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.functions.{col, udf, lit}
+import org.apache.spark.sql.functions._
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -173,8 +174,11 @@ class LinearRegression(override val uid: String)
         summaryModel.transform(dataset),
         predictionColName,
         $(labelCol),
+        summaryModel,
+        model.diagInvAtWA.toArray,
         $(featuresCol),
         Array(0D))
+
       return lrModel.setSummary(trainingSummary)
     }
 
@@ -221,6 +225,8 @@ class LinearRegression(override val uid: String)
         summaryModel.transform(dataset),
         predictionColName,
         $(labelCol),
+        model,
+        Array(0D),
         $(featuresCol),
         Array(0D))
       return copyValues(model.setSummary(trainingSummary))
@@ -316,6 +322,8 @@ class LinearRegression(override val uid: String)
       summaryModel.transform(dataset),
       predictionColName,
       $(labelCol),
+      model,
+      Array(0D),
       $(featuresCol),
       objectiveHistory)
     model.setSummary(trainingSummary)
@@ -371,7 +379,8 @@ class LinearRegressionModel private[ml] (
   private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
     // Handle possible missing or invalid prediction columns
     val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
-    new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, $(labelCol))
+    new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName,
+      $(labelCol), this, Array(0D))
   }
 
   /**
@@ -412,9 +421,11 @@ class LinearRegressionTrainingSummary private[regression] (
     predictions: DataFrame,
     predictionCol: String,
     labelCol: String,
+    model: LinearRegressionModel,
+    diagInvAtWA: Array[Double],
     val featuresCol: String,
     val objectiveHistory: Array[Double])
-  extends LinearRegressionSummary(predictions, predictionCol, labelCol) {
+  extends LinearRegressionSummary(predictions, predictionCol, labelCol, model, diagInvAtWA) {
 
   /** Number of training iterations until termination */
   val totalIterations = objectiveHistory.length
@@ -430,7 +441,9 @@ class LinearRegressionTrainingSummary private[regression] (
 class LinearRegressionSummary private[regression] (
     @transient val predictions: DataFrame,
     val predictionCol: String,
-    val labelCol: String) extends Serializable {
+    val labelCol: String,
+    val model: LinearRegressionModel,
+    val diagInvAtWA: Array[Double]) extends Serializable {
 
   @transient private val metrics = new RegressionMetrics(
     predictions
@@ -474,6 +487,75 @@ class LinearRegressionSummary private[regression] (
     predictions.select(t(col(predictionCol), col(labelCol)).as("residuals"))
   }
 
+  /** Number of instances in DataFrame predictions */
+  lazy val numInstances: Long = predictions.count()
+
+  /** Degrees of freedom */
+  private val degreesOfFreedom: Long = if (model.getFitIntercept) {
+    numInstances - model.coefficients.size - 1
+  } else {
+    numInstances - model.coefficients.size
+  }
+
+  /**
+   * The weighted residuals, the usual residuals rescaled by
+   * the square root of the instance weights.
+   */
+  lazy val devianceResiduals: Array[Double] = {
+    val weighted = if (model.getWeightCol.isEmpty) lit(1.0) else sqrt(col(model.getWeightCol))
+    val dr = predictions.select(col(model.getLabelCol).minus(col(model.getPredictionCol))
+      .multiply(weighted).as("weightedResiduals"))
+      .select(min(col("weightedResiduals")).as("min"), max(col("weightedResiduals")).as("max"))
+      .first()
+    Array(dr.getDouble(0), dr.getDouble(1))
+  }
+
+  /**
+   * Standard error of estimated coefficients.
+   * Note that standard error of estimated intercept is not supported currently.
+   */
+  lazy val coefficientStandardErrors: Array[Double] = {
+    if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
+      throw new UnsupportedOperationException(
+        "No Std. Error of coefficients available for this LinearRegressionModel")
+    } else {
+      val rss = if (model.getWeightCol.isEmpty) {
+        meanSquaredError * numInstances
+      } else {
+        val t = udf { (pred: Double, label: Double, weight: Double) =>
+          math.pow(label - pred, 2.0) * weight }
+        predictions.select(t(col(model.getPredictionCol), col(model.getLabelCol),
+          col(model.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0)
+      }
+      val sigma2 = rss / degreesOfFreedom
+      diagInvAtWA.map(_ * sigma2).map(math.sqrt(_))
+    }
+  }
+
+  /** T-statistic of estimated coefficients.
+   * Note that t-statistic of estimated intercept is not supported currently.
+   */
+  lazy val tValues: Array[Double] = {
+    if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
+      throw new UnsupportedOperationException(
+        "No t-statistic available for this LinearRegressionModel")
+    } else {
+      model.coefficients.toArray.zip(coefficientStandardErrors).map { x => x._1 / x._2 }
+    }
+  }
+
+  /** Two-sided p-value of estimated coefficients.
+   * Note that p-value of estimated intercept is not supported currently.
+   */
+  lazy val pValues: Array[Double] = {
+    if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
+      throw new UnsupportedOperationException(
+        "No p-value available for this LinearRegressionModel")
+    } else {
+      tValues.map { x => 2.0 * (1.0 - StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x))) }
+    }
+  }
+
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala
index 66eb40b6f4..0cd371e9cc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala
@@ -40,4 +40,20 @@ private[spark] object CholeskyDecomposition {
     assert(code == 0, s"lapack.dpotrs returned $code.")
     bx
   }
+
+  /**
+   * Computes the inverse of a real symmetric positive definite matrix A
+   * using the Cholesky factorization A = U**T*U.
+   * The input argument UAi is modified in place to store the inverse matrix.
+   * @param UAi the upper triangular factor U from the Cholesky factorization A = U**T*U
+   * @param k the dimension of A
+   * @return the upper triangle of the (symmetric) inverse of A
+   */
+  def inverse(UAi: Array[Double], k: Int): Array[Double] = {
+    val info = new intW(0)
+    lapack.dpptri("U", k, UAi, info)
+    val code = info.`val`
+    assert(code == 0, s"lapack.dpptri returned $code.")
+    UAi
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 235c796d78..fbf83e8922 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -35,6 +35,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
   @transient var datasetWithDenseFeature: DataFrame = _
   @transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _
   @transient var datasetWithSparseFeature: DataFrame = _
+  @transient var datasetWithWeight: DataFrame = _
 
   /*
      In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -73,6 +74,22 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
         xMean = Seq.fill(featureSize)(r.nextDouble).toArray,
         xVariance = Seq.fill(featureSize)(r.nextDouble).toArray,
         nPoints = 200, seed, eps = 0.1, sparsity = 0.7), 2))
+
+    /*
+       R code:
+
+       A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+       b <- c(17, 19, 23, 29)
+       w <- c(1, 2, 3, 4)
+       df <- as.data.frame(cbind(A, b))
+     */
+    datasetWithWeight = sqlContext.createDataFrame(
+      sc.parallelize(Seq(
+        Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+        Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
+        Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
+        Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
+      ), 2))
   }
 
   test("params") {
@@ -603,6 +620,16 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
         // To clarify that the normal solver is used here.
         assert(model.summary.objectiveHistory.length == 1)
         assert(model.summary.objectiveHistory(0) == 0.0)
+        val devianceResidualsR = Array(-0.35566, 0.34504)
+        val seCoefR = Array(0.0011756, 0.0009032)
+        val tValsR = Array(3998, 7971)
+        val pValsR = Array(0, 0)
+        model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x =>
+          assert(x._1 ~== x._2 absTol 1E-3) }
+        model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x =>
+          assert(x._1 ~== x._2 absTol 1E-3) }
+        model.summary.tValues.map(_.round).zip(tValsR).foreach{ x => assert(x._1 === x._2) }
+        model.summary.pValues.map(_.round).zip(pValsR).foreach{ x => assert(x._1 === x._2) }
       }
     }
   }
@@ -725,4 +752,106 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       .sliding(2)
       .forall(x => x(0) >= x(1)))
   }
+
+  test("linear regression summary with weighted samples and intercept by normal solver") {
+    /*
+       R code:
+
+       model <- glm(formula = "b ~ .", data = df, weights = w)
+       summary(model)
+
+       Call:
+       glm(formula = "b ~ .", data = df, weights = w)
+
+       Deviance Residuals:
+           1      2      3      4
+       1.920 -1.358 -1.109  0.960
+
+       Coefficients:
+                   Estimate Std. Error t value Pr(>|t|)
+       (Intercept)   18.080      9.608   1.882    0.311
+       V1             6.080      5.556   1.094    0.471
+       V2            -0.600      1.960  -0.306    0.811
+
+       (Dispersion parameter for gaussian family taken to be 7.68)
+
+           Null deviance: 202.00  on 3  degrees of freedom
+       Residual deviance:   7.68  on 1  degrees of freedom
+       AIC: 18.783
+
+       Number of Fisher Scoring iterations: 2
+     */
+
+    val model = new LinearRegression()
+      .setWeightCol("weight")
+      .setSolver("normal")
+      .fit(datasetWithWeight)
+    val coefficientsR = Vectors.dense(Array(6.080, -0.600))
+    val interceptR = 18.080
+    val devianceResidualsR = Array(-1.358, 1.920)
+    val seCoefR = Array(5.556, 1.960)
+    val tValsR = Array(1.094, -0.306)
+    val pValsR = Array(0.471, 0.811)
+
+    assert(model.coefficients ~== coefficientsR absTol 1E-3)
+    assert(model.intercept ~== interceptR absTol 1E-3)
+    model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x =>
+      assert(x._1 ~== x._2 absTol 1E-3) }
+    model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x =>
+      assert(x._1 ~== x._2 absTol 1E-3) }
+    model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+    model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+  }
+
+  test("linear regression summary with weighted samples and w/o intercept by normal solver") {
+    /*
+       R code:
+
+       model <- glm(formula = "b ~ . -1", data = df, weights = w)
+       summary(model)
+
+       Call:
+       glm(formula = "b ~ . -1", data = df, weights = w)
+
+       Deviance Residuals:
+            1       2       3       4
+        1.950   2.344  -4.600   2.103
+
+       Coefficients:
+          Estimate Std. Error t value Pr(>|t|)
+       V1  -3.7271     2.9032  -1.284   0.3279
+       V2   3.0100     0.6022   4.998   0.0378 *
+       ---
+       Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
+
+       (Dispersion parameter for gaussian family taken to be 17.4376)
+
+           Null deviance: 5962.000  on 4  degrees of freedom
+       Residual deviance:   34.875  on 2  degrees of freedom
+       AIC: 22.835
+
+       Number of Fisher Scoring iterations: 2
+     */
+
+    val model = new LinearRegression()
+      .setWeightCol("weight")
+      .setSolver("normal")
+      .setFitIntercept(false)
+      .fit(datasetWithWeight)
+    val coefficientsR = Vectors.dense(Array(-3.7271, 3.0100))
+    val interceptR = 0.0
+    val devianceResidualsR = Array(-4.600, 2.344)
+    val seCoefR = Array(2.9032, 0.6022)
+    val tValsR = Array(-1.284, 4.998)
+    val pValsR = Array(0.3279, 0.0378)
+
+    assert(model.coefficients ~== coefficientsR absTol 1E-3)
+    assert(model.intercept === interceptR)
+    model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x =>
+      assert(x._1 ~== x._2 absTol 1E-3) }
+    model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x =>
+      assert(x._1 ~== x._2 absTol 1E-3) }
+    model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+    model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+  }
 }
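
A note on the packed indexing used in the WeightedLeastSquares hunk above: LAPACK's packed ("P") storage keeps only the upper triangle of a symmetric matrix, column by column, so entry (i, j) with i <= j (1-based) sits at flat index j*(j-1)/2 + i - 1, and the diagonal entry (i, i) at i + i*(i-1)/2 - 1, which is exactly the `aaInv(i + (i - 1) * i / 2 - 1)` expression in the patch. A minimal self-contained Scala sketch (the matrix values and the object name are made up for illustration):

    object PackedDiagonalDemo extends App {
      val k = 3
      // Upper triangle of a symmetric 3x3 matrix, packed column-major:
      // a11, a12, a22, a13, a23, a33
      val packed = Array(1.0, 2.0, 4.0, 3.0, 5.0, 6.0)
      // Same index arithmetic as WeightedLeastSquares: the i-th diagonal
      // entry (1-based) lives at flat index i + (i - 1) * i / 2 - 1.
      val diag = (1 to k).map(i => packed(i + (i - 1) * i / 2 - 1))
      println(diag.mkString(", "))  // prints 1.0, 4.0, 6.0 (a11, a22, a33)
    }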
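The new CholeskyDecomposition.inverse delegates to LAPACK's dpptri, which expects the packed Cholesky factor U as input (in the solver, aaBar.values already holds that factor after the earlier solve call). The sketch below runs the factor-then-invert pair on a tiny hand-checked matrix through netlib-java, the binding mllib already uses; it assumes com.github.fommil.netlib is on the classpath and the object name is hypothetical:

    import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
    import org.netlib.util.intW

    object PackedInverseDemo extends App {
      val k = 2
      // Upper triangle of A = [[4, 2], [2, 3]], packed column-major: a11, a12, a22.
      val ap = Array(4.0, 2.0, 3.0)
      val info = new intW(0)
      lapack.dpptrf("U", k, ap, info)  // in-place Cholesky factorization A = U**T*U
      assert(info.`val` == 0, s"lapack.dpptrf returned ${info.`val`}.")
      lapack.dpptri("U", k, ap, info)  // in-place inverse computed from the factor
      assert(info.`val` == 0, s"lapack.dpptri returned ${info.`val`}.")
      // det(A) = 8, so A^-1 = [[0.375, -0.25], [-0.25, 0.5]]
      println(ap.mkString(", "))  // prints 0.375, -0.25, 0.5
    }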
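The summary statistics themselves follow the standard OLS recipe visible in the LinearRegressionSummary hunk: with sigma^2 = RSS / degreesOfFreedom, the standard error of coefficient j is sqrt(diagInvAtWA(j) * sigma^2), the t-value is the coefficient over its standard error, and the two-sided p-value comes from a Student's t distribution with the same degrees of freedom. A small sketch of that arithmetic outside Spark, using the same breeze StudentsT as the patch; the input numbers are hypothetical, chosen so the output roughly echoes the weighted R fit quoted in the tests:

    import breeze.stats.distributions.StudentsT

    object OlsStatsSketch extends App {
      val coefficients = Array(6.08, -0.6)  // hypothetical fitted coefficients
      val diagInvAtWA = Array(4.02, 0.50)   // hypothetical diagonal of (A^T * W * A)^-1
      val rss = 7.68                        // hypothetical residual sum of squares
      val degreesOfFreedom = 1.0
      val sigma2 = rss / degreesOfFreedom
      // Standard errors: sqrt of each diagonal entry scaled by sigma^2.
      val stdErrs = diagInvAtWA.map(v => math.sqrt(v * sigma2))
      // t-values: coefficient divided by its standard error.
      val tValues = coefficients.zip(stdErrs).map { case (beta, se) => beta / se }
      // Two-sided p-values from the Student's t CDF, as in the patch.
      val pValues = tValues.map(t => 2.0 * (1.0 - StudentsT(degreesOfFreedom).cdf(math.abs(t))))
      println(stdErrs.mkString(", "))   // ~ 5.556, 1.960
      println(tValues.mkString(", "))   // ~ 1.094, -0.306
      println(pValues.mkString(", "))   // ~ 0.471, 0.811
    }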
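From the user's side, the new statistics are exposed on model.summary whenever the model was fit with the normal-equation solver; the L-BFGS path passes the Array(0D) sentinel, so the guards above throw UnsupportedOperationException there. A hypothetical end-to-end usage sketch against this patch's API, mirroring the new weighted test (the object name and the local-mode setup are assumptions for illustration):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.mllib.linalg.Vectors

    object OlsSummaryUsage extends App {
      val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("ols-summary"))
      val sqlContext = new SQLContext(sc)
      // The same toy data as datasetWithWeight in the test suite.
      val df = sqlContext.createDataFrame(Seq(
        (17.0, 1.0, Vectors.dense(0.0, 5.0)),
        (19.0, 2.0, Vectors.dense(1.0, 7.0)),
        (23.0, 3.0, Vectors.dense(2.0, 11.0)),
        (29.0, 4.0, Vectors.dense(3.0, 13.0))
      )).toDF("label", "weight", "features")

      val model = new LinearRegression()
        .setSolver("normal")   // the R-like statistics require the normal solver
        .setWeightCol("weight")
        .fit(df)

      println(model.summary.devianceResiduals.mkString(", "))
      println(model.summary.coefficientStandardErrors.mkString(", "))
      println(model.summary.tValues.mkString(", "))
      println(model.summary.pValues.mkString(", "))
      sc.stop()
    }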