aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-11-05 09:56:18 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-05 09:56:18 -0800
commit9da7ceed81b0afce7deb8f39f3a6d565d401a391 (patch)
treee6c70d2cc92cfdfd48ee366ad6230f9b87795230 /mllib/src/main/scala/org
parentb072ff4d1d05fc212cd7036d1897a032a395f0b3 (diff)
downloadspark-9da7ceed81b0afce7deb8f39f3a6d565d401a391.tar.gz
spark-9da7ceed81b0afce7deb8f39f3a6d565d401a391.tar.bz2
spark-9da7ceed81b0afce7deb8f39f3a6d565d401a391.zip
[SPARK-11473][ML] R-like summary statistics with intercept for OLS via normal equation solver
Follow up [SPARK-9836](https://issues.apache.org/jira/browse/SPARK-9836), we should also support summary statistics for ```intercept```. Author: Yanbo Liang <ybliang8@gmail.com> Closes #9485 from yanboliang/spark-11473.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala35
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala22
2 files changed, 32 insertions, 25 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index e612a2122e..8617722ae5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -75,7 +75,7 @@ private[ml] class WeightedLeastSquares(
val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_))
summary.validate()
logInfo(s"Number of instances: ${summary.count}.")
- val k = summary.k
+ val k = if (fitIntercept) summary.k + 1 else summary.k
val triK = summary.triK
val wSum = summary.wSum
val bBar = summary.bBar
@@ -86,14 +86,6 @@ private[ml] class WeightedLeastSquares(
val aaBar = summary.aaBar
val aaValues = aaBar.values
- if (fitIntercept) {
- // shift centers
- // A^T A - aBar aBar^T
- BLAS.spr(-1.0, aBar, aaValues)
- // A^T b - bBar aBar
- BLAS.axpy(-bBar, aBar, abBar)
- }
-
// add regularization to diagonals
var i = 0
var j = 2
@@ -111,21 +103,32 @@ private[ml] class WeightedLeastSquares(
j += 1
}
- val x = new DenseVector(CholeskyDecomposition.solve(aaBar.values, abBar.values))
+ val aa = if (fitIntercept) {
+ Array.concat(aaBar.values, aBar.values, Array(1.0))
+ } else {
+ aaBar.values
+ }
+ val ab = if (fitIntercept) {
+ Array.concat(abBar.values, Array(bBar))
+ } else {
+ abBar.values
+ }
+
+ val x = CholeskyDecomposition.solve(aa, ab)
+
+ val aaInv = CholeskyDecomposition.inverse(aa, k)
- val aaInv = CholeskyDecomposition.inverse(aaBar.values, k)
// aaInv is a packed upper triangular matrix, here we get all elements on diagonal
val diagInvAtWA = new DenseVector((1 to k).map { i =>
aaInv(i + (i - 1) * i / 2 - 1) / wSum }.toArray)
- // compute intercept
- val intercept = if (fitIntercept) {
- bBar - BLAS.dot(aBar, x)
+ val (coefficients, intercept) = if (fitIntercept) {
+ (new DenseVector(x.slice(0, x.length - 1)), x.last)
} else {
- 0.0
+ (new DenseVector(x), 0.0)
}
- new WeightedLeastSquaresModel(x, intercept, diagInvAtWA)
+ new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index c51e30483a..6638313818 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -511,8 +511,7 @@ class LinearRegressionSummary private[regression] (
}
/**
- * Standard error of estimated coefficients.
- * Note that standard error of estimated intercept is not supported currently.
+ * Standard error of estimated coefficients and intercept.
*/
lazy val coefficientStandardErrors: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -532,21 +531,26 @@ class LinearRegressionSummary private[regression] (
}
}
- /** T-statistic of estimated coefficients.
- * Note that t-statistic of estimated intercept is not supported currently.
- */
+ /**
+ * T-statistic of estimated coefficients and intercept.
+ */
lazy val tValues: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
throw new UnsupportedOperationException(
"No t-statistic available for this LinearRegressionModel")
} else {
- model.coefficients.toArray.zip(coefficientStandardErrors).map { x => x._1 / x._2 }
+ val estimate = if (model.getFitIntercept) {
+ Array.concat(model.coefficients.toArray, Array(model.intercept))
+ } else {
+ model.coefficients.toArray
+ }
+ estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 }
}
}
- /** Two-sided p-value of estimated coefficients.
- * Note that p-value of estimated intercept is not supported currently.
- */
+ /**
+ * Two-sided p-value of estimated coefficients and intercept.
+ */
lazy val pValues: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
throw new UnsupportedOperationException(