author     Joseph K. Bradley <joseph@databricks.com>    2016-11-08 12:58:29 -0800
committer  Joseph K. Bradley <joseph@databricks.com>    2016-11-08 12:58:29 -0800
commit     26e1c53aceee37e3687a372ff6c6f05463fd8a94 (patch)
tree       8cb7c651a72b4e5484bf345890c207e3b1ae4458 /mllib/src/main
parent     245e5a2f80e3195b7f8a38b480b29bfc23af66bf (diff)
[SPARK-17748][ML] Minor cleanups to one-pass linear regression with elastic net
## What changes were proposed in this pull request?

* Made SingularMatrixException private[spark]
* WeightedLeastSquares: Changed to allow tol >= 0 instead of only tol > 0

## How was this patch tested?

Existing tests.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #15779 from jkbradley/wls-cleanups.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala   |  9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala   |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala  | 22
3 files changed, 23 insertions, 12 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala
index 2f5299b010..96fd0d18b5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala
@@ -16,9 +16,10 @@
*/
package org.apache.spark.ml.optim
+import scala.collection.mutable
+
import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
-import scala.collection.mutable
import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vectors}
import org.apache.spark.mllib.linalg.CholeskyDecomposition
@@ -57,7 +58,7 @@ private[ml] sealed trait NormalEquationSolver {
*/
private[ml] class CholeskySolver extends NormalEquationSolver {
- def solve(
+ override def solve(
bBar: Double,
bbBar: Double,
abBar: DenseVector,
@@ -80,7 +81,7 @@ private[ml] class QuasiNewtonSolver(
tol: Double,
l1RegFunc: Option[(Int) => Double]) extends NormalEquationSolver {
- def solve(
+ override def solve(
bBar: Double,
bbBar: Double,
abBar: DenseVector,
@@ -156,7 +157,7 @@ private[ml] class QuasiNewtonSolver(
* Exception thrown when solving a linear system Ax = b for which the matrix A is non-invertible
* (singular).
*/
-class SingularMatrixException(message: String, cause: Throwable)
+private[spark] class SingularMatrixException(message: String, cause: Throwable)
extends IllegalArgumentException(message, cause) {
def this(message: String) = this(message, null)
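Note on the visibility change above: once SingularMatrixException is private[spark], only code inside the org.apache.spark namespace can catch it by type. A minimal sketch of that internal usage pattern (illustrative only, not the actual Spark call sites; solveAnalytic and solveIterative are hypothetical placeholders):

```scala
package org.apache.spark.ml.optim

import org.apache.spark.ml.linalg.DenseVector

// Illustrative sketch: an internal caller can still catch the now
// package-private exception and fall back to an iterative solver when
// the normal-equation system A^T A x = A^T b is singular.
private[ml] object NormalEquationFallbackSketch {
  def solve(
      solveAnalytic: => DenseVector,    // hypothetical: may throw SingularMatrixException
      solveIterative: => DenseVector    // hypothetical: quasi-Newton fallback
    ): DenseVector = {
    try {
      solveAnalytic
    } catch {
      case _: SingularMatrixException =>
        // A^T A is not invertible; retry with the iterative path.
        solveIterative
    }
  }
}
```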
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 90c24e1b59..56ab967570 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -47,7 +47,7 @@ private[ml] class WeightedLeastSquaresModel(
* formulation:
*
* min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w,,i,,
- * + lambda / delta (1/2 (1 - alpha) sumj,, (sigma,,j,, x,,j,,)^2^
+ * + lambda / delta (1/2 (1 - alpha) sum,,j,, (sigma,,j,, x,,j,,)^2^
* + alpha sum,,j,, abs(sigma,,j,, x,,j,,)),
*
* where lambda is the regularization parameter, alpha is the ElasticNet mixing parameter,
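For reference, the objective in the ScalaDoc hunk above (with the sum,,j,, subscript typo fixed) can be restated in standard notation; lambda, alpha, sigma,,j,, and delta are as defined in the surrounding ScalaDoc:

```latex
\min_{x,\,z}\;
  \frac{1}{2\sum_i w_i}\sum_i w_i \left(a_i^{T} x + z - b_i\right)^2
  + \frac{\lambda}{\delta}\left(
      \frac{1-\alpha}{2}\sum_j (\sigma_j x_j)^2
      + \alpha \sum_j \lvert \sigma_j x_j \rvert
    \right)
```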
@@ -91,7 +91,7 @@ private[ml] class WeightedLeastSquares(
require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0,
s"elasticNetParam must be in [0, 1]: $elasticNetParam")
require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter")
- require(tol > 0, s"tol must be greater than zero: $tol")
+ require(tol >= 0.0, s"tol must be >= 0, but was set to $tol")
/**
* Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s.
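The relaxed require above makes tol = 0.0 a legal setting; with a zero tolerance the quasi-Newton path effectively skips early convergence stopping, so the number of iterations is driven by maxIter. A hedged sketch of reaching that path from the public API, assuming LinearRegression forwards tol and maxIter to WeightedLeastSquares when solver = "normal" and elasticNetParam > 0:

```scala
import org.apache.spark.ml.regression.LinearRegression

// Sketch only: elasticNetParam > 0 with the "normal" solver selects the
// quasi-Newton (OWL-QN) solver inside WeightedLeastSquares, where the
// relaxed check now accepts tol = 0.0 (previously require(tol > 0) threw).
val lr = new LinearRegression()
  .setSolver("normal")
  .setElasticNetParam(0.5)
  .setRegParam(0.1)
  .setMaxIter(50)
  .setTol(0.0)
```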
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index ae876b3839..9639b07496 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.linalg.BLAS._
-import org.apache.spark.ml.optim.{NormalEquationSolver, WeightedLeastSquares}
+import org.apache.spark.ml.optim.WeightedLeastSquares
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
@@ -160,11 +160,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
/**
* Set the solver algorithm used for optimization.
* In case of linear regression, this can be "l-bfgs", "normal" and "auto".
- * "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton
- * optimization method. "normal" denotes using Normal Equation as an analytical
- * solution to the linear regression problem.
- * The default value is "auto" which means that the solver algorithm is
- * selected automatically.
+ * - "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton
+ * optimization method.
+ * - "normal" denotes using Normal Equation as an analytical solution to the linear regression
+ * problem. This solver is limited to [[LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER]].
+ * - "auto" (default) means that the solver algorithm is selected automatically.
+ * The Normal Equations solver will be used when possible, but this will automatically fall
+ * back to iterative optimization methods when needed.
*
* @group setParam
*/
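A short usage sketch of the three solver values documented above (illustrative; setSolver is part of the existing public API):

```scala
import org.apache.spark.ml.regression.LinearRegression

val auto   = new LinearRegression().setSolver("auto")    // default: solver chosen automatically
val normal = new LinearRegression().setSolver("normal")  // analytic normal-equation solution
val lbfgs  = new LinearRegression().setSolver("l-bfgs")  // limited-memory quasi-Newton
```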
@@ -404,6 +406,14 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] {
@Since("1.6.0")
override def load(path: String): LinearRegression = super.load(path)
+
+ /**
+ * When using [[LinearRegression.solver]] == "normal", the solver must limit the number of
+ * features to at most this number. The entire covariance matrix X^T^X will be collected
+ * to the driver. This limit helps prevent memory overflow errors.
+ */
+ @Since("2.1.0")
+ val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES
}
/**
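The new MAX_FEATURES_FOR_NORMAL_SOLVER constant is public, so callers can check dimensionality before requesting the "normal" solver. A minimal caller-side sketch (the selection rule here is illustrative; "auto" already performs an equivalent fallback internally):

```scala
import org.apache.spark.ml.regression.LinearRegression

// Hypothetical helper: choose a solver up front based on feature count,
// using the constant introduced in this patch.
def chooseSolver(numFeatures: Int): String =
  if (numFeatures <= LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER) "normal"
  else "l-bfgs"
```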