author     Xiangrui Meng <meng@databricks.com>  2015-09-08 20:51:20 -0700
committer  Xiangrui Meng <meng@databricks.com>  2015-09-08 20:51:20 -0700
commit     52fe32f6ac7a04fa9b4478fda1307c5b0c61c4a2 (patch)
tree       3e2e955c99a17eee1a5b54179da650949db22b03
parent     820913f554bef610d07ca2dadaead657f916ae63 (diff)
[SPARK-9834] [MLLIB] implement weighted least squares via normal equation
The goal of this PR is to have a weighted least squares implementation that takes the normal equation approach, and hence is able to provide R-like summary statistics and support IRLS (used by GLMs). The tests match R's lm and glmnet.

There are a couple of TODOs that can be addressed in future PRs:
* consolidate summary statistics aggregators
* move `dspr` to `BLAS`
* etc.

It would be nice to have this merged first because it blocks a couple of other features. dbtsai

Author: Xiangrui Meng <meng@databricks.com>

Closes #8588 from mengxr/SPARK-9834.
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala      296
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala                    7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala   3
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala 133
4 files changed, 438 insertions(+), 1 deletion(-)
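
For reference, the objective the new solver minimizes, rewritten from the ScalaDoc markup in the class comment below into standard notation:

    \min_{x,z}\ \frac{1}{2\sum_i w_i}\sum_i w_i\,(a_i^T x + z - b_i)^2 \;+\; \frac{\lambda}{2\delta}\sum_j (\sigma_j x_j)^2

Setting the gradient with respect to x to zero reduces this to the regularized normal equation (\overline{A^T A} + \Lambda)\,x = \overline{A^T b}, where the bars denote weighted means (centered when an intercept is fit) and \Lambda is diagonal with entries \lambda\,\sigma_j^2/\delta. These are exactly the aaBar and abBar statistics the Aggregator computes and choleskySolve consumes.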
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
new file mode 100644
index 0000000000..a99e2ac4c6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -0,0 +1,296 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.optim
+
+import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
+import org.netlib.util.intW
+
+import org.apache.spark.Logging
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.linalg.distributed.RowMatrix
+import org.apache.spark.rdd.RDD
+
+/**
+ * Model fitted by [[WeightedLeastSquares]].
+ * @param coefficients model coefficients
+ * @param intercept model intercept
+ */
+private[ml] class WeightedLeastSquaresModel(
+ val coefficients: DenseVector,
+ val intercept: Double) extends Serializable
+
+/**
+ * Weighted least squares solver via normal equation.
+ * Given weighted observations (w,,i,,, a,,i,,, b,,i,,), we use the following weighted least squares
+ * formulation:
+ *
+ * min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w,,i,,
+ * + 1/2 lambda / delta sum,,j,, (sigma,,j,, x,,j,,)^2^,
+ *
+ * where lambda is the regularization parameter, and delta and sigma,,j,, are controlled by
+ * [[standardizeLabel]] and [[standardizeFeatures]], respectively.
+ *
+ * Set [[regParam]] to 0.0 and turn off both [[standardizeFeatures]] and [[standardizeLabel]] to
+ * match R's `lm`.
+ * Turn on [[standardizeLabel]] to match R's `glmnet`.
+ *
+ * @param fitIntercept whether to fit intercept. If false, z is 0.0.
+ * @param regParam L2 regularization parameter (lambda)
+ * @param standardizeFeatures whether to standardize features. If true, sigma,,j,, is the
+ * population standard deviation of the j-th column of A. Otherwise,
+ * sigma,,j,, is 1.0.
+ * @param standardizeLabel whether to standardize label. If true, delta is the population standard
+ * deviation of the label column b. Otherwise, delta is 1.0.
+ */
+private[ml] class WeightedLeastSquares(
+ val fitIntercept: Boolean,
+ val regParam: Double,
+ val standardizeFeatures: Boolean,
+ val standardizeLabel: Boolean) extends Logging with Serializable {
+ import WeightedLeastSquares._
+
+ require(regParam >= 0.0, s"regParam cannot be negative: $regParam")
+ if (regParam == 0.0) {
+ logWarning("regParam is zero, which might cause numerical instability and overfitting.")
+ }
+
+ /**
+ * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s.
+ */
+ def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = {
+ val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_))
+ summary.validate()
+ logInfo(s"Number of instances: ${summary.count}.")
+ val triK = summary.triK
+ val bBar = summary.bBar
+ val bStd = summary.bStd
+ val aBar = summary.aBar
+ val aVar = summary.aVar
+ val abBar = summary.abBar
+ val aaBar = summary.aaBar
+ val aaValues = aaBar.values
+
+ if (fitIntercept) {
+ // shift centers
+      // aaBar -= aBar aBar^T
+ RowMatrix.dspr(-1.0, aBar, aaValues)
+      // abBar -= bBar aBar
+ BLAS.axpy(-bBar, aBar, abBar)
+ }
+
+    // add regularization to the diagonal entries of the packed upper-triangular aaBar
+ var i = 0
+ var j = 2
+ while (i < triK) {
+ var lambda = regParam
+ if (standardizeFeatures) {
+ lambda *= aVar(j - 2)
+ }
+ if (standardizeLabel) {
+ // TODO: handle the case when bStd = 0
+ lambda /= bStd
+ }
+ aaValues(i) += lambda
+ i += j
+ j += 1
+ }
+
+ val x = choleskySolve(aaBar.values, abBar)
+
+ // compute intercept
+ val intercept = if (fitIntercept) {
+ bBar - BLAS.dot(aBar, x)
+ } else {
+ 0.0
+ }
+
+ new WeightedLeastSquaresModel(x, intercept)
+ }
+
+ /**
+ * Solves a symmetric positive definite linear system via Cholesky factorization.
+ * The input arguments are modified in-place to store the factorization and the solution.
+   * @param A the upper triangular part of the matrix, packed column-major
+   * @param bx the right-hand side; overwritten with the solution
+ * @return the solution vector
+ */
+ // TODO: SPARK-10490 - consolidate this and the Cholesky solver in ALS
+ private def choleskySolve(A: Array[Double], bx: DenseVector): DenseVector = {
+ val k = bx.size
+ val info = new intW(0)
+ lapack.dppsv("U", k, 1, A, bx.values, k, info)
+ val code = info.`val`
+    assert(code == 0, s"lapack.dppsv returned $code.")
+ bx
+ }
+}
+
+private[ml] object WeightedLeastSquares {
+
+ /**
+ * Case class for weighted observations.
+   * @param w weight, must be non-negative
+ * @param a features
+ * @param b label
+ */
+ case class Instance(w: Double, a: Vector, b: Double) {
+ require(w >= 0.0, s"Weight cannot be negative: $w.")
+ }
+
+ /**
+ * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
+ */
+ // TODO: consolidate aggregates for summary statistics
+ private class Aggregator extends Serializable {
+ var initialized: Boolean = false
+ var k: Int = _
+ var count: Long = _
+ var triK: Int = _
+ private var wSum: Double = _
+ private var wwSum: Double = _
+ private var bSum: Double = _
+ private var bbSum: Double = _
+ private var aSum: DenseVector = _
+ private var abSum: DenseVector = _
+ private var aaSum: DenseVector = _
+
+ private def init(k: Int): Unit = {
+      require(k <= 4096, "The normal equation approach supports at most 4096 features, " +
+        s"but got $k.")
+ this.k = k
+ triK = k * (k + 1) / 2
+ count = 0L
+ wSum = 0.0
+ wwSum = 0.0
+ bSum = 0.0
+ bbSum = 0.0
+ aSum = new DenseVector(Array.ofDim(k))
+ abSum = new DenseVector(Array.ofDim(k))
+ aaSum = new DenseVector(Array.ofDim(triK))
+ initialized = true
+ }
+
+ /**
+ * Adds an instance.
+ */
+ def add(instance: Instance): this.type = {
+ val Instance(w, a, b) = instance
+ val ak = a.size
+ if (!initialized) {
+ init(ak)
+ initialized = true
+ }
+      assert(ak == k, s"Dimension mismatch. Expected vectors of size $k but got $ak.")
+ count += 1L
+ wSum += w
+ wwSum += w * w
+ bSum += w * b
+ bbSum += w * b * b
+ BLAS.axpy(w, a, aSum)
+ BLAS.axpy(w * b, a, abSum)
+ RowMatrix.dspr(w, a, aaSum.values)
+ this
+ }
+
+ /**
+ * Merges another [[Aggregator]].
+ */
+ def merge(other: Aggregator): this.type = {
+ if (!other.initialized) {
+ this
+ } else {
+ if (!initialized) {
+ init(other.k)
+ }
+      assert(k == other.k, s"Dimension mismatch: this.k = $k but other.k = ${other.k}.")
+ count += other.count
+ wSum += other.wSum
+ wwSum += other.wwSum
+ bSum += other.bSum
+ bbSum += other.bbSum
+ BLAS.axpy(1.0, other.aSum, aSum)
+ BLAS.axpy(1.0, other.abSum, abSum)
+ BLAS.axpy(1.0, other.aaSum, aaSum)
+ this
+ }
+ }
+
+ /**
+ * Validates that we have seen observations.
+ */
+ def validate(): Unit = {
+ assert(initialized, "Training dataset is empty.")
+ assert(wSum > 0.0, "Sum of weights cannot be zero.")
+ }
+
+ /**
+ * Weighted mean of features.
+ */
+ def aBar: DenseVector = {
+ val output = aSum.copy
+ BLAS.scal(1.0 / wSum, output)
+ output
+ }
+
+ /**
+ * Weighted mean of labels.
+ */
+ def bBar: Double = bSum / wSum
+
+ /**
+ * Weighted population standard deviation of labels.
+ */
+ def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
+
+ /**
+ * Weighted mean of (label * features).
+ */
+ def abBar: DenseVector = {
+ val output = abSum.copy
+ BLAS.scal(1.0 / wSum, output)
+ output
+ }
+
+ /**
+ * Weighted mean of (features * features^T^).
+ */
+ def aaBar: DenseVector = {
+ val output = aaSum.copy
+ BLAS.scal(1.0 / wSum, output)
+ output
+ }
+
+ /**
+ * Weighted population variance of features.
+ */
+ def aVar: DenseVector = {
+ val variance = Array.ofDim[Double](k)
+ var i = 0
+ var j = 2
+ val aaValues = aaSum.values
+ while (i < triK) {
+ val l = j - 2
+ val aw = aSum(l) / wSum
+ variance(l) = aaValues(i) / wSum - aw * aw
+ i += j
+ j += 1
+ }
+ new DenseVector(variance)
+ }
+ }
+}
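
A minimal, runnable sketch of the two numerical devices used in WeightedLeastSquares above, assuming only the netlib-java dependency the file already imports: the column-major packed upper-triangular storage whose diagonal the regularization loop (and aVar) walks with `i += j; j += 1`, and the LAPACK dppsv driver that choleskySolve wraps. The 2x2 system and its values are made up purely for illustration.

import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
import org.netlib.util.intW

object PackedCholeskySketch {
  def main(args: Array[String]): Unit = {
    // Upper triangle of the SPD matrix [[4, 2], [2, 3]] packed column-major:
    // entries (1,1), (1,2), (2,2) = 4.0, 2.0, 3.0.
    val aPacked = Array(4.0, 2.0, 3.0)
    val k = 2
    val triK = k * (k + 1) / 2

    // The diagonal walk from the regularization loop: packed indices 0, 2, 5, 9, ...
    var i = 0
    var j = 2
    while (i < triK) {
      println(s"diagonal entry ${j - 1} sits at packed index $i: ${aPacked(i)}")
      i += j
      j += 1
    }

    // Solve A x = b in place via packed Cholesky, exactly as choleskySolve does.
    val b = Array(8.0, 7.0)
    val info = new intW(0)
    lapack.dppsv("U", k, 1, aPacked, b, k, info)
    assert(info.`val` == 0, s"lapack.dppsv returned ${info.`val`}.")
    println(s"solution: ${b.mkString(", ")}") // (1.25, 1.5)
  }
}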
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index ab475af264..9ee81eda8a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -92,6 +92,13 @@ private[spark] object BLAS extends Serializable with Logging {
}
}
+  /** Y += a * X */
+ private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = {
+ require(X.numRows == Y.numRows && X.numCols == Y.numCols, "Dimension mismatch: " +
+ s"size(X) = ${(X.numRows, X.numCols)} but size(Y) = ${(Y.numRows, Y.numCols)}.")
+ f2jBLAS.daxpy(X.numRows * X.numCols, a, X.values, 1, Y.values, 1)
+ }
+
/**
* dot(x, y)
*/
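
The new overload treats both matrices as their flat backing arrays, so it is an element-wise Y += a * X for any equal-shaped pair. A hypothetical usage sketch; note that BLAS is private[spark], so this only compiles inside the org.apache.spark namespace:

import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix}

val x = new DenseMatrix(2, 2, Array(1.0, 2.0, 3.0, 4.0))
val y = DenseMatrix.zeros(2, 2)
BLAS.axpy(2.0, x, y) // y.values is now (2.0, 4.0, 6.0, 8.0)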
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 9a423ddafd..83779ac889 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -678,7 +678,8 @@ object RowMatrix {
*
* @param U the upper triangular part of the matrix packed in an array (column major)
*/
- private def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
+ // TODO: SPARK-10491 - move this method to linalg.BLAS
+ private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
// TODO: Find a better home (breeze?) for this method.
val n = v.size
v match {
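
For readers unfamiliar with packed storage: dspr performs the symmetric rank-1 update U += alpha * v * v^T while storing only the upper triangle column-major, which is how the Aggregator above accumulates the weighted Gram matrix one instance at a time. A plain-Scala sketch of the equivalent arithmetic for a dense v (illustration only, not the netlib implementation):

// Entry (row r, col c) of the packed upper triangle, r <= c, 0-based,
// lives at index c * (c + 1) / 2 + r.
def dsprEquivalent(alpha: Double, v: Array[Double], u: Array[Double]): Unit = {
  val n = v.length
  var c = 0
  while (c < n) {
    var r = 0
    while (r <= c) {
      u(c * (c + 1) / 2 + r) += alpha * v(r) * v(c)
      r += 1
    }
    c += 1
  }
}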
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
new file mode 100644
index 0000000000..652f3adb98
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.optim
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.optim.WeightedLeastSquares.Instance
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.rdd.RDD
+
+class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ private var instances: RDD[Instance] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+ b <- c(17, 19, 23, 29)
+ w <- c(1, 2, 3, 4)
+ */
+ instances = sc.parallelize(Seq(
+ Instance(1.0, Vectors.dense(0.0, 5.0).toSparse, 17.0),
+ Instance(2.0, Vectors.dense(1.0, 7.0), 19.0),
+ Instance(3.0, Vectors.dense(2.0, 11.0), 23.0),
+ Instance(4.0, Vectors.dense(3.0, 13.0), 29.0)
+ ), 2)
+ }
+
+ test("WLS against lm") {
+ /*
+ R code:
+
+ df <- as.data.frame(cbind(A, b))
+ for (formula in c(b ~ . -1, b ~ .)) {
+ model <- lm(formula, data=df, weights=w)
+ print(as.vector(coef(model)))
+ }
+
+ [1] -3.727121 3.009983
+ [1] 18.08 6.08 -0.60
+ */
+
+ val expected = Seq(
+ Vectors.dense(0.0, -3.727121, 3.009983),
+ Vectors.dense(18.08, 6.08, -0.60))
+
+ var idx = 0
+ for (fitIntercept <- Seq(false, true)) {
+ val wls = new WeightedLeastSquares(
+ fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false)
+ .fit(instances)
+ val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
+ assert(actual ~== expected(idx) absTol 1e-4)
+ idx += 1
+ }
+ }
+
+ test("WLS against glmnet") {
+ /*
+ R code:
+
+ library(glmnet)
+
+ for (intercept in c(FALSE, TRUE)) {
+ for (lambda in c(0.0, 0.1, 1.0)) {
+ for (standardize in c(FALSE, TRUE)) {
+ model <- glmnet(A, b, weights=w, intercept=intercept, lambda=lambda,
+ standardize=standardize, alpha=0, thresh=1E-14)
+ print(as.vector(coef(model)))
+ }
+ }
+ }
+
+ [1] 0.000000 -3.727117 3.009982
+ [1] 0.000000 -3.727117 3.009982
+ [1] 0.000000 -3.307532 2.924206
+ [1] 0.000000 -2.914790 2.840627
+ [1] 0.000000 -1.526575 2.558158
+ [1] 0.00000000 0.06984238 2.20488344
+ [1] 18.0799727 6.0799832 -0.5999941
+ [1] 18.0799727 6.0799832 -0.5999941
+ [1] 13.5356178 3.2714044 0.3770744
+ [1] 14.064629 3.565802 0.269593
+ [1] 10.1238013 0.9708569 1.1475466
+ [1] 13.1860638 2.1761382 0.6213134
+ */
+
+ val expected = Seq(
+ Vectors.dense(0.0, -3.727117, 3.009982),
+ Vectors.dense(0.0, -3.727117, 3.009982),
+ Vectors.dense(0.0, -3.307532, 2.924206),
+ Vectors.dense(0.0, -2.914790, 2.840627),
+ Vectors.dense(0.0, -1.526575, 2.558158),
+ Vectors.dense(0.0, 0.06984238, 2.20488344),
+ Vectors.dense(18.0799727, 6.0799832, -0.5999941),
+ Vectors.dense(18.0799727, 6.0799832, -0.5999941),
+ Vectors.dense(13.5356178, 3.2714044, 0.3770744),
+ Vectors.dense(14.064629, 3.565802, 0.269593),
+ Vectors.dense(10.1238013, 0.9708569, 1.1475466),
+ Vectors.dense(13.1860638, 2.1761382, 0.6213134))
+
+ var idx = 0
+ for (fitIntercept <- Seq(false, true);
+ regParam <- Seq(0.0, 0.1, 1.0);
+ standardizeFeatures <- Seq(false, true)) {
+ val wls = new WeightedLeastSquares(
+ fitIntercept, regParam, standardizeFeatures, standardizeLabel = true)
+ .fit(instances)
+ val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1))
+ assert(actual ~== expected(idx) absTol 1e-4)
+ idx += 1
+ }
+ }
+}