aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xdocs/_layouts/global.html5
-rw-r--r--docs/mllib-classification-regression.md294
-rw-r--r--docs/mllib-optimization.md164
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala23
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala42
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala61
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala31
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala30
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala17
9 files changed, 535 insertions, 132 deletions
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index b65686c0b1..7114e1f5dd 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -196,6 +196,11 @@
</body>
<!-- MathJax Section -->
+ <script type="text/x-mathjax-config">
+ MathJax.Hub.Config({
+ TeX: { equationNumbers: { autoNumber: "AMS" } }
+ });
+ </script>
<script type="text/javascript"
src="http://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script>
diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md
index edb9338907..18a3e8e075 100644
--- a/docs/mllib-classification-regression.md
+++ b/docs/mllib-classification-regression.md
@@ -7,45 +7,256 @@ title: MLlib - Classification and Regression
{:toc}
-# Binary Classification
-
-Binary classification is a supervised learning problem in which we want to
-classify entities into one of two distinct categories or labels, e.g.,
-predicting whether or not emails are spam. This problem involves executing a
-learning *Algorithm* on a set of *labeled* examples, i.e., a set of entities
-represented via (numerical) features along with underlying category labels.
-The algorithm returns a trained *Model* that can predict the label for new
-entities for which the underlying label is unknown.
-
-MLlib currently supports two standard model families for binary classification,
-namely [Linear Support Vector Machines
-(SVMs)](http://en.wikipedia.org/wiki/Support_vector_machine) and [Logistic
-Regression](http://en.wikipedia.org/wiki/Logistic_regression), along with [L1
-and L2 regularized](http://en.wikipedia.org/wiki/Regularization_(mathematics))
-variants of each model family. The training algorithms all leverage an
-underlying gradient descent primitive (described
-[below](#gradient-descent-primitive)), and take as input a regularization
-parameter (*regParam*) along with various parameters associated with gradient
-descent (*stepSize*, *numIterations*, *miniBatchFraction*).
+`\[
+\newcommand{\R}{\mathbb{R}}
+\newcommand{\E}{\mathbb{E}}
+\newcommand{\x}{\mathbf{x}}
+\newcommand{\y}{\mathbf{y}}
+\newcommand{\wv}{\mathbf{w}}
+\newcommand{\av}{\mathbf{\alpha}}
+\newcommand{\bv}{\mathbf{b}}
+\newcommand{\N}{\mathbb{N}}
+\newcommand{\id}{\mathbf{I}}
+\newcommand{\ind}{\mathbf{1}}
+\newcommand{\0}{\mathbf{0}}
+\newcommand{\unit}{\mathbf{e}}
+\newcommand{\one}{\mathbf{1}}
+\newcommand{\zero}{\mathbf{0}}
+\]`
+
+
+# Supervised Machine Learning
+Supervised machine learning is the setting where we are given a set of training data examples
+`$\{\x_i\}$`, each example `$\x_i$` coming with a corresponding label `$y_i$`.
+Given the training data `$\{(\x_i,y_i)\}$`, we want to learn a function to predict these labels.
+The two most well known classes of methods are
+[classification](http://en.wikipedia.org/wiki/Statistical_classification), and
+[regression](http://en.wikipedia.org/wiki/Regression_analysis).
+In classification, the label is a category (e.g. whether or not emails are spam), whereas in
+regression, the label is real value, and we want our prediction to be as close to the true value
+as possible.
+
+Supervised Learning involves executing a learning *Algorithm* on a set of *labeled* training
+examples. The algorithm returns a trained *Model* (such as for example a linear function) that
+can predict the label for new data examples for which the label is unknown.
+
+
+## Mathematical Formulation
+Many standard *machine learning* methods can be formulated as a convex optimization problem, i.e.
+the task of finding a minimizer of a convex function `$f$` that depends on a variable vector
+`$\wv$` (called `weights` in the code), which has `$d$` entries.
+Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where
+the objective function is of the form
+`\begin{equation}
+ f(\wv) :=
+ \lambda\, R(\wv) +
+ \frac1n \sum_{i=1}^n L(\wv;\x_i,y_i)
+ \label{eq:regPrimal}
+ \ .
+\end{equation}`
+Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and
+`$y_i\in\R$` are their corresponding labels, which we want to predict.
+
+The objective function `$f$` has two parts:
+The *loss-function* measures the error of the model on the training data. The loss-function
+`$L(\wv;.)$` must be a convex function in `$\wv$`.
+The purpose of the [regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to
+encourage simple models, by punishing the complexity of the model `$\wv$`, in order to e.g. avoid
+over-fitting.
+Usually, the regularizer `$R(.)$` is chosen as either the standard (Euclidean) L2-norm, `$R(\wv)
+:= \frac{1}{2}\|\wv\|^2$`, or the L1-norm, `$R(\wv) := \|\wv\|_1$`, see
+[below](#using-different-regularizers) for more details.
+
+The fixed regularization parameter `$\lambda\ge0$` (`regParam` in the code) defines the trade-off
+between the two goals of small loss and small model complexity.
+
+
+## Binary Classification
+
+**Input:** Datapoints `$\x_i\in\R^{d}$`, labels `$y_i\in\{+1,-1\}$`, for `$1\le i\le n$`.
+
+**Distributed Datasets.**
+For all currently implemented optimization methods for classification, the data must be
+distributed between the worker machines *by examples*. Every machine holds a consecutive block of
+the `$n$` example/label pairs `$(\x_i,y_i)$`.
+In other words, the input distributed dataset
+([RDD](scala-programming-guide.html#resilient-distributed-datasets-rdds)) must be the set of
+vectors `$\x_i\in\R^d$`.
+
+### Support Vector Machine
+The linear [Support Vector Machine (SVM)](http://en.wikipedia.org/wiki/Support_vector_machine)
+has become a standard choice for classification tasks.
+Here the loss function in formulation `$\eqref{eq:regPrimal}$` is given by the hinge-loss
+`\[
+L(\wv;\x_i,y_i) := \max \{0, 1-y_i \wv^T \x_i \} \ .
+\]`
+
+By default, SVMs are trained with an L2 regularization, which gives rise to the large-margin
+interpretation if these classifiers. We also support alternative L1 regularization. In this case,
+the primal optimization problem becomes an [LP](http://en.wikipedia.org/wiki/Linear_programming).
+
+### Logistic Regression
+Despite its name, [Logistic Regression](http://en.wikipedia.org/wiki/Logistic_regression) is a
+binary classification method, again when the labels are given by binary values
+`$y_i\in\{+1,-1\}$`. The logistic loss function in formulation `$\eqref{eq:regPrimal}$` is
+defined as
+`\[
+L(\wv;\x_i,y_i) := \log(1+\exp( -y_i \wv^T \x_i)) \ .
+\]`
+
+
+## Linear Regression (Least Squares, Lasso and Ridge Regression)
+
+**Input:** Data matrix `$A\in\R^{n\times d}$`, right hand side vector `$\y\in\R^n$`.
+
+**Distributed Datasets.**
+For all currently implemented optimization methods for regression, the data matrix
+`$A\in\R^{n\times d}$` must be distributed between the worker machines *by rows* of `$A$`. In
+other words, the input distributed dataset
+([RDD](scala-programming-guide.html#resilient-distributed-datasets-rdds)) must be the set of the
+`$n$` rows `$A_{i:}$` of `$A$`.
+
+Least Squares Regression refers to the setting where we try to fit a vector `$\y\in\R^n$` by
+linear combination of our observed data `$A\in\R^{n\times d}$`, which is given as a matrix.
+
+It comes in 3 flavors:
+
+### Least Squares
+Plain old [least squares](http://en.wikipedia.org/wiki/Least_squares) linear regression is the
+problem of minimizing
+ `\[ f_{\text{LS}}(\wv) := \frac1n \|A\wv-\y\|_2^2 \ . \]`
+
+### Lasso
+The popular [Lasso](http://en.wikipedia.org/wiki/Lasso_(statistics)#Lasso_method) (alternatively
+also known as `$L_1$`-regularized least squares regression) is given by
+ `\[ f_{\text{Lasso}}(\wv) := \frac1n \|A\wv-\y\|_2^2 + \lambda \|\wv\|_1 \ . \]`
+
+### Ridge Regression
+[Ridge regression](http://en.wikipedia.org/wiki/Ridge_regression) uses the same loss function but
+with a L2 regularizer term:
+ `\[ f_{\text{Ridge}}(\wv) := \frac1n \|A\wv-\y\|_2^2 + \frac{\lambda}{2}\|\wv\|^2 \ . \]`
+
+**Loss Function.**
+For all 3, the loss function (i.e. the measure of model fit) is given by the squared deviations
+from the right hand side `$\y$`.
+`\[
+\frac1n \|A\wv-\y\|_2^2
+= \frac1n \sum_{i=1}^n (A_{i:} \wv - y_i )^2
+= \frac1n \sum_{i=1}^n L(\wv;\x_i,y_i)
+\]`
+This is also known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_error).
+In our generic problem formulation `$\eqref{eq:regPrimal}$`, this means the loss function is
+`$L(\wv;\x_i,y_i) := (A_{i:} \wv - y_i )^2$`, each depending only on a single row `$A_{i:}$` of
+the data matrix `$A$`.
+
+
+## Using Different Regularizers
+
+As we have mentioned above, the purpose of *regularizer* in `$\eqref{eq:regPrimal}$` is to
+encourage simple models, by punishing the complexity of the model `$\wv$`, in order to e.g. avoid
+over-fitting.
+All machine learning methods for classification and regression that we have mentioned above are
+of interest for different types of regularization, the 3 most common ones being
+
+* **L2-Regularization.**
+`$R(\wv) := \frac{1}{2}\|\wv\|^2$`.
+This regularizer is most commonly used for SVMs, logistic regression and ridge regression.
+
+* **L1-Regularization.**
+`$R(\wv) := \|\wv\|_1$`. The L1 norm `$\|\wv\|_1$` is the sum of the absolut values of the
+entries of a vector `$\wv$`.
+This regularizer is most commonly used for sparse methods, and feature selection, such as the
+Lasso.
+
+* **Non-Regularized.**
+`$R(\wv):=0$`.
+Of course we can also train the models without any regularization, or equivalently by setting the
+regularization parameter `$\lambda:=0$`.
+
+The optimization problems of the form `$\eqref{eq:regPrimal}$` with convex regularizers such as
+the 3 mentioned here can be conveniently optimized with gradient descent type methods (such as
+SGD) which is implemented in `MLlib` currently, and explained in the next section.
+
+
+# Optimization Methods Working on the Primal Formulation
+
+**Stochastic subGradient Descent (SGD).**
+For optimization objectives `$f$` written as a sum, *stochastic subgradient descent (SGD)* can be
+an efficient choice of optimization method, as we describe in the <a
+href="mllib-optimization.html">optimization section</a> in more detail.
+Because all methods considered here fit into the optimization formulation
+`$\eqref{eq:regPrimal}$`, this is especially natural, because the loss is written as an average
+of the individual losses coming from each datapoint.
+
+Picking one datapoint `$i\in[1..n]$` uniformly at random, we obtain a stochastic subgradient of
+`$\eqref{eq:regPrimal}$`, with respect to `$\wv$` as follows:
+`\[
+f'_{\wv,i} := L'_{\wv,i} + \lambda\, R'_\wv \ ,
+\]`
+where `$L'_{\wv,i} \in \R^d$` is a subgradient of the part of the loss function determined by the
+`$i$`-th datapoint, that is `$L'_{\wv,i} \in \frac{\partial}{\partial \wv} L(\wv;\x_i,y_i)$`.
+Furthermore, `$R'_\wv$` is a subgradient of the regularizer `$R(\wv)$`, i.e. `$R'_\wv \in
+\frac{\partial}{\partial \wv} R(\wv)$`. The term `$R'_\wv$` does not depend on which random
+datapoint is picked.
+
+
+
+**Gradients.**
+The following table summarizes the gradients (or subgradients) of all loss functions and
+regularizers that we currently support:
+
+<table class="table">
+ <thead>
+ <tr><th></th><th>Function</th><th>Stochastic (Sub)Gradient</th></tr>
+ </thead>
+ <tbody>
+ <tr>
+ <td>SVM Hinge Loss</td><td>$L(\wv;\x_i,y_i) := \max \{0, 1-y_i \wv^T \x_i \}$</td>
+ <td>$L'_{\wv,i} = \begin{cases}-y_i \x_i & \text{if $y_i \wv^T \x_i <1$}, \\ 0 &
+\text{otherwise}.\end{cases}$</td>
+ </tr>
+ <tr>
+ <td>Logistic Loss</td><td>$L(\wv;\x_i,y_i) := \log(1+\exp( -y_i \wv^T \x_i))$</td>
+ <td>$L'_{\wv,i} = -y_i \x_i \left(1-\frac1{1+\exp(-y_i \wv^T \x_i)} \right)$</td>
+ </tr>
+ <tr>
+ <td>Least Squares Loss</td><td>$L(\wv;\x_i,y_i) := (A_{i:} \wv - y_i)^2$</td>
+ <td>$L'_{\wv,i} = 2 A_{i:}^T (A_{i:} \wv - y_i)$</td>
+ </tr>
+ <tr>
+ <td>Non-Regularized</td><td>$R(\wv) := 0$</td><td>$R'_\wv = \0$</td>
+ </tr>
+ <tr>
+ <td>L2 Regularizer</td><td>$R(\wv) := \frac{1}{2}\|\wv\|^2$</td><td>$R'_\wv = \wv$</td>
+ </tr>
+ <tr>
+ <td>L1 Regularizer</td><td>$R(\wv) := \|\wv\|_1$</td><td>$R'_\wv = \mathop{sign}(\wv)$</td>
+ </tr>
+ </tbody>
+</table>
+
+Here `$\mathop{sign}(\wv)$` is the vector consisting of the signs (`$\pm1$`) of all the entries
+of `$\wv$`.
+Also, note that `$A_{i:} \in \R^d$` is a row-vector, but the gradient is a column vector.
+
+
+
+## Implementation in MLlib
+
+For both classification and regression, `MLlib` implements a simple distributed version of
+stochastic subgradient descent (SGD), building on the underlying gradient descent primitive (as
+described in the
+<a href="mllib-optimization.html">optimization section</a>).
+All provided algorithms take as input a regularization parameter (`regParam`) along with various
+parameters associated with stochastic gradient
+descent (`stepSize`, `numIterations`, `miniBatchFraction`).
+For each of them, we support all 3 possible regularizations (none, L1 or L2).
Available algorithms for binary classification:
* [SVMWithSGD](api/mllib/index.html#org.apache.spark.mllib.classification.SVMWithSGD)
* [LogisticRegressionWithSGD](api/mllib/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD)
-# Linear Regression
-
-Linear regression is another classical supervised learning setting. In this
-problem, each entity is associated with a real-valued label (as opposed to a
-binary label as in binary classification), and we want to predict labels as
-closely as possible given numerical features representing entities. MLlib
-supports linear regression as well as L1
-([lasso](http://en.wikipedia.org/wiki/Lasso_(statistics)#Lasso_method)) and L2
-([ridge](http://en.wikipedia.org/wiki/Ridge_regression)) regularized variants.
-The regression algorithms in MLlib also leverage the underlying gradient
-descent primitive (described [below](#gradient-descent-primitive)), and have
-the same parameters as the binary classification algorithms described above.
-
Available algorithms for linear regression:
* [LinearRegressionWithSGD](api/mllib/index.html#org.apache.spark.mllib.regression.LinearRegressionWithSGD)
@@ -59,6 +270,9 @@ gradient descent primitive in MLlib, see the
* [GradientDescent](api/mllib/index.html#org.apache.spark.mllib.optimization.GradientDescent)
+
+
+
# Usage in Scala
Following code snippets can be executed in `spark-shell`.
@@ -115,9 +329,10 @@ val modelL1 = svmAlg.run(parsedData)
{% endhighlight %}
## Linear Regression
-The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. The
-example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We
-compute the Mean Squared Error at the end to evaluate
+
+The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint.
+The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
+values. We compute the Mean Squared Error at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit)
{% highlight scala %}
@@ -157,6 +372,7 @@ Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a
calling `.rdd()` on your `JavaRDD` object.
# Usage in Python
+
Following examples can be tested in the PySpark shell.
## Binary Classification
@@ -182,9 +398,9 @@ print("Training Error = " + str(trainErr))
{% endhighlight %}
## Linear Regression
-The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. The
-example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We
-compute the Mean Squared Error at the end to evaluate
+The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint.
+The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
+values. We compute the Mean Squared Error at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit)
{% highlight python %}
diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md
index 428284ef29..396b98d52a 100644
--- a/docs/mllib-optimization.md
+++ b/docs/mllib-optimization.md
@@ -6,35 +6,161 @@ title: MLlib - Optimization
* Table of contents
{:toc}
+`\[
+\newcommand{\R}{\mathbb{R}}
+\newcommand{\E}{\mathbb{E}}
+\newcommand{\x}{\mathbf{x}}
+\newcommand{\y}{\mathbf{y}}
+\newcommand{\wv}{\mathbf{w}}
+\newcommand{\av}{\mathbf{\alpha}}
+\newcommand{\bv}{\mathbf{b}}
+\newcommand{\N}{\mathbb{N}}
+\newcommand{\id}{\mathbf{I}}
+\newcommand{\ind}{\mathbf{1}}
+\newcommand{\0}{\mathbf{0}}
+\newcommand{\unit}{\mathbf{e}}
+\newcommand{\one}{\mathbf{1}}
+\newcommand{\zero}{\mathbf{0}}
+\]`
-# Gradient Descent Primitive
-[Gradient descent](http://en.wikipedia.org/wiki/Gradient_descent) (along with
-stochastic variants thereof) are first-order optimization methods that are
-well-suited for large-scale and distributed computation. Gradient descent
-methods aim to find a local minimum of a function by iteratively taking steps
-in the direction of the negative gradient of the function at the current point,
-i.e., the current parameter value. Gradient descent is included as a low-level
-primitive in MLlib, upon which various ML algorithms are developed, and has the
-following parameters:
-* *gradient* is a class that computes the stochastic gradient of the function
+# Mathematical Description
+
+## (Sub)Gradient Descent
+The simplest method to solve optimization problems of the form `$\min_{\wv \in\R^d} \; f(\wv)$`
+is [gradient descent](http://en.wikipedia.org/wiki/Gradient_descent).
+Such first-order optimization methods (including gradient descent and stochastic variants
+thereof) are well-suited for large-scale and distributed computation.
+
+Gradient descent methods aim to find a local minimum of a function by iteratively taking steps in
+the direction of steepest descent, which is the negative of the derivative (called the
+[gradient](http://en.wikipedia.org/wiki/Gradient)) of the function at the current point, i.e., at
+the current parameter value.
+If the objective function `$f$` is not differentiable at all arguments, but still convex, then a
+*subgradient*
+is the natural generalization of the gradient, and assumes the role of the step direction.
+In any case, computing a gradient or subgradient of `$f$` is expensive --- it requires a full
+pass through the complete dataset, in order to compute the contributions from all loss terms.
+
+## Stochastic (Sub)Gradient Descent (SGD)
+Optimization problems whose objective function `$f$` is written as a sum are particularly
+suitable to be solved using *stochastic subgradient descent (SGD)*.
+In our case, for the optimization formulations commonly used in <a
+href="mllib-classification-regression.html">supervised machine learning</a>,
+`\begin{equation}
+ f(\wv) :=
+ \lambda\, R(\wv) +
+ \frac1n \sum_{i=1}^n L(\wv;\x_i,y_i)
+ \label{eq:regPrimal}
+ \ .
+\end{equation}`
+this is especially natural, because the loss is written as an average of the individual losses
+coming from each datapoint.
+
+A stochastic subgradient is a randomized choice of a vector, such that in expectation, we obtain
+a true subgradient of the original objective function.
+Picking one datapoint `$i\in[1..n]$` uniformly at random, we obtain a stochastic subgradient of
+`$\eqref{eq:regPrimal}$`, with respect to `$\wv$` as follows:
+`\[
+f'_{\wv,i} := L'_{\wv,i} + \lambda\, R'_\wv \ ,
+\]`
+where `$L'_{\wv,i} \in \R^d$` is a subgradient of the part of the loss function determined by the
+`$i$`-th datapoint, that is `$L'_{\wv,i} \in \frac{\partial}{\partial \wv} L(\wv;\x_i,y_i)$`.
+Furthermore, `$R'_\wv$` is a subgradient of the regularizer `$R(\wv)$`, i.e. `$R'_\wv \in
+\frac{\partial}{\partial \wv} R(\wv)$`. The term `$R'_\wv$` does not depend on which random
+datapoint is picked.
+Clearly, in expectation over the random choice of `$i\in[1..n]$`, we have that `$f'_{\wv,i}$` is
+a subgradient of the original objective `$f$`, meaning that `$\E\left[f'_{\wv,i}\right] \in
+\frac{\partial}{\partial \wv} f(\wv)$`.
+
+Running SGD now simply becomes walking in the direction of the negative stochastic subgradient
+`$f'_{\wv,i}$`, that is
+`\begin{equation}\label{eq:SGDupdate}
+\wv^{(t+1)} := \wv^{(t)} - \gamma \; f'_{\wv,i} \ .
+\end{equation}`
+**Step-size.**
+The parameter `$\gamma$` is the step-size, which in the default implementation is chosen
+decreasing with the square root of the iteration counter, i.e. `$\gamma := \frac{s}{\sqrt{t}}$`
+in the `$t$`-th iteration, with the input parameter `$s=$ stepSize`. Note that selecting the best
+step-size for SGD methods can often be delicate in practice and is a topic of active research.
+
+**Gradients.**
+A table of (sub)gradients of the machine learning methods implemented in MLlib, is available in
+the <a href="mllib-classification-regression.html">classification and regression</a> section.
+
+
+**Proximal Updates.**
+As an alternative to just use the subgradient `$R'(\wv)$` of the regularizer in the step
+direction, an improved update for some cases can be obtained by using the proximal operator
+instead.
+For the L1-regularizer, the proximal operator is given by soft thresholding, as implemented in
+[L1Updater](api/mllib/index.html#org.apache.spark.mllib.optimization.L1Updater).
+
+
+## Update Schemes for Distributed SGD
+The SGD implementation in
+[GradientDescent](api/mllib/index.html#org.apache.spark.mllib.optimization.GradientDescent) uses
+a simple (distributed) sampling of the data examples.
+We recall that the loss part of the optimization problem `$\eqref{eq:regPrimal}$` is
+`$\frac1n \sum_{i=1}^n L(\wv;\x_i,y_i)$`, and therefore `$\frac1n \sum_{i=1}^n L'_{\wv,i}$` would
+be the true (sub)gradient.
+Since this would require access to the full data set, the parameter `miniBatchFraction` specifies
+which fraction of the full data to use instead.
+The average of the gradients over this subset, i.e.
+`\[
+\frac1{|S|} \sum_{i\in S} L'_{\wv,i} \ ,
+\]`
+is a stochastic gradient. Here `$S$` is the sampled subset of size `$|S|=$ miniBatchFraction
+$\cdot n$`.
+
+In each iteration, the sampling over the distributed dataset
+([RDD](scala-programming-guide.html#resilient-distributed-datasets-rdds)), as well as the
+computation of the sum of the partial results from each worker machine is performed by the
+standard spark routines.
+
+If the fraction of points `miniBatchFraction` is set to 1 (default), then the resulting step in
+each iteration is exact (sub)gradient descent. In this case there is no randomness and no
+variance in the used step directions.
+On the other extreme, if `miniBatchFraction` is chosen very small, such that only a single point
+is sampled, i.e. `$|S|=$ miniBatchFraction $\cdot n = 1$`, then the algorithm is equivalent to
+standard SGD. In that case, the step direction depends from the uniformly random sampling of the
+point.
+
+
+
+# Implementation in MLlib
+
+Gradient descent methods including stochastic subgradient descent (SGD) as
+included as a low-level primitive in `MLlib`, upon which various ML algorithms
+are developed, see the
+<a href="mllib-classification-regression.html">classification and regression</a>
+section for example.
+
+The SGD method
+[GradientDescent.runMiniBatchSGD](api/mllib/index.html#org.apache.spark.mllib.optimization.GradientDescent)
+has the following parameters:
+
+* `gradient` is a class that computes the stochastic gradient of the function
being optimized, i.e., with respect to a single training example, at the
current parameter value. MLlib includes gradient classes for common loss
functions, e.g., hinge, logistic, least-squares. The gradient class takes as
input a training example, its label, and the current parameter value.
-* *updater* is a class that updates weights in each iteration of gradient
-descent. MLlib includes updaters for cases without regularization, as well as
+* `updater` is a class that performs the actual gradient descent step, i.e.
+updating the weights in each iteration, for a given gradient of the loss part.
+The updater is also responsible to perform the update from the regularization
+part. MLlib includes updaters for cases without regularization, as well as
L1 and L2 regularizers.
-* *stepSize* is a scalar value denoting the initial step size for gradient
+* `stepSize` is a scalar value denoting the initial step size for gradient
descent. All updaters in MLlib use a step size at the t-th step equal to
-stepSize / sqrt(t).
-* *numIterations* is the number of iterations to run.
-* *regParam* is the regularization parameter when using L1 or L2 regularization.
-* *miniBatchFraction* is the fraction of the data used to compute the gradient
-at each iteration.
+`stepSize $/ \sqrt{t}$`.
+* `numIterations` is the number of iterations to run.
+* `regParam` is the regularization parameter when using L1 or L2 regularization.
+* `miniBatchFraction` is the fraction of the total data that is sampled in
+each iteration, to compute the gradient direction.
Available algorithms for gradient descent:
-* [GradientDescent](api/mllib/index.html#org.apache.spark.mllib.optimization.GradientDescent)
+* [GradientDescent.runMiniBatchSGD](api/mllib/index.html#org.apache.spark.mllib.optimization.GradientDescent)
+
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index c590492e7a..82124703da 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -24,10 +24,10 @@ import org.jblas.DoubleMatrix
*/
abstract class Gradient extends Serializable {
/**
- * Compute the gradient and loss given features of a single data point.
+ * Compute the gradient and loss given the features of a single data point.
*
- * @param data - Feature values for one data point. Column matrix of size nx1
- * where n is the number of features.
+ * @param data - Feature values for one data point. Column matrix of size dx1
+ * where d is the number of features.
* @param label - Label for this data item.
* @param weights - Column matrix containing weights for every feature.
*
@@ -40,7 +40,8 @@ abstract class Gradient extends Serializable {
}
/**
- * Compute gradient and loss for a logistic loss function.
+ * Compute gradient and loss for a logistic loss function, as used in binary classification.
+ * See also the documentation for the precise formulation.
*/
class LogisticGradient extends Gradient {
override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
@@ -61,22 +62,26 @@ class LogisticGradient extends Gradient {
}
/**
- * Compute gradient and loss for a Least-squared loss function.
+ * Compute gradient and loss for a Least-squared loss function, as used in linear regression.
+ * This is correct for the averaged least squares loss function (mean squared error)
+ * L = 1/n ||A weights-y||^2
+ * See also the documentation for the precise formulation.
*/
-class SquaredGradient extends Gradient {
+class LeastSquaresGradient extends Gradient {
override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
(DoubleMatrix, Double) = {
val diff: Double = data.dot(weights) - label
- val loss = 0.5 * diff * diff
- val gradient = data.mul(diff)
+ val loss = diff * diff
+ val gradient = data.mul(2.0 * diff)
(gradient, loss)
}
}
/**
- * Compute gradient and loss for a Hinge loss function.
+ * Compute gradient and loss for a Hinge loss function, as used in SVM binary classification.
+ * See also the documentation for the precise formulation.
* NOTE: This assumes that the labels are {0,1}
*/
class HingeGradient extends Gradient {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index cd80134737..8e87b98bac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -17,9 +17,8 @@
package org.apache.spark.mllib.optimization
-import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
import org.jblas.DoubleMatrix
@@ -39,7 +38,8 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
private var miniBatchFraction: Double = 1.0
/**
- * Set the step size per-iteration of SGD. Default 1.0.
+ * Set the initial step size of SGD for the first step. Default 1.0.
+ * In subsequent steps, the step size will decrease with stepSize/sqrt(t)
*/
def setStepSize(step: Double): this.type = {
this.stepSize = step
@@ -47,7 +47,8 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
}
/**
- * Set fraction of data to be used for each SGD iteration. Default 1.0.
+ * Set fraction of data to be used for each SGD iteration.
+ * Default 1.0 (corresponding to deterministic/classical gradient descent)
*/
def setMiniBatchFraction(fraction: Double): this.type = {
this.miniBatchFraction = fraction
@@ -63,7 +64,7 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
}
/**
- * Set the regularization parameter used for SGD. Default 0.0.
+ * Set the regularization parameter. Default 0.0.
*/
def setRegParam(regParam: Double): this.type = {
this.regParam = regParam
@@ -71,7 +72,8 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
}
/**
- * Set the gradient function to be used for SGD.
+ * Set the gradient function (of the loss function of one single data example)
+ * to be used for SGD.
*/
def setGradient(gradient: Gradient): this.type = {
this.gradient = gradient
@@ -80,7 +82,9 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
/**
- * Set the updater function to be used for SGD.
+ * Set the updater function to actually perform a gradient step in a given direction.
+ * The updater is responsible to perform the update from the regularization term as well,
+ * and therefore determines what kind or regularization is used, if any.
*/
def setUpdater(updater: Updater): this.type = {
this.updater = updater
@@ -107,20 +111,26 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
// Top-level method to run gradient descent.
object GradientDescent extends Logging {
/**
- * Run gradient descent in parallel using mini batches.
+ * Run stochastic gradient descent (SGD) in parallel using mini batches.
+ * In each iteration, we sample a subset (fraction miniBatchFraction) of the total data
+ * in order to compute a gradient estimate.
+ * Sampling, and averaging the subgradients over this subset is performed using one standard
+ * spark map-reduce in each iteration.
*
- * @param data - Input data for SGD. RDD of form (label, [feature values]).
- * @param gradient - Gradient object that will be used to compute the gradient.
- * @param updater - Updater object that will be used to update the model.
- * @param stepSize - stepSize to be used during update.
+ * @param data - Input data for SGD. RDD of the set of data examples, each of
+ * the form (label, [feature values]).
+ * @param gradient - Gradient object (used to compute the gradient of the loss function of
+ * one single data example)
+ * @param updater - Updater function to actually perform a gradient step in a given direction.
+ * @param stepSize - initial step size for the first step
* @param numIterations - number of iterations that SGD should be run.
* @param regParam - regularization parameter
* @param miniBatchFraction - fraction of the input data set that should be used for
* one iteration of SGD. Default value 1.0.
*
* @return A tuple containing two elements. The first element is a column matrix containing
- * weights for every feature, and the second element is an array containing the stochastic
- * loss computed for every iteration.
+ * weights for every feature, and the second element is an array containing the
+ * stochastic loss computed for every iteration.
*/
def runMiniBatchSGD(
data: RDD[(Double, Array[Double])],
@@ -142,6 +152,8 @@ object GradientDescent extends Logging {
var regVal = 0.0
for (i <- 1 to numIterations) {
+ // Sample a subset (fraction miniBatchFraction) of the total data
+ // compute and sum up the subgradients on this subset (this is one map-reduce)
val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i).map {
case (y, features) =>
val featuresCol = new DoubleMatrix(features.length, 1, features:_*)
@@ -160,7 +172,7 @@ object GradientDescent extends Logging {
regVal = update._2
}
- logInfo("GradientDescent finished. Last 10 stochastic losses %s".format(
+ logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
stochasticLossHistory.takeRight(10).mkString(", ")))
(weights.toArray, stochasticLossHistory.toArray)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
index 37124f261e..889a03e3e6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
@@ -21,16 +21,25 @@ import scala.math._
import org.jblas.DoubleMatrix
/**
- * Class used to update weights used in Gradient Descent.
+ * Class used to perform steps (weight update) using Gradient Descent methods.
+ *
+ * For general minimization problems, or for regularized problems of the form
+ * min L(w) + regParam * R(w),
+ * the compute function performs the actual update step, when given some
+ * (e.g. stochastic) gradient direction for the loss L(w),
+ * and a desired step-size (learning rate).
+ *
+ * The updater is responsible to also perform the update coming from the
+ * regularization term R(w) (if any regularization is used).
*/
abstract class Updater extends Serializable {
/**
* Compute an updated value for weights given the gradient, stepSize, iteration number and
- * regularization parameter. Also returns the regularization value computed using the
- * *updated* weights.
+ * regularization parameter. Also returns the regularization value regParam * R(w)
+ * computed using the *updated* weights.
*
- * @param weightsOld - Column matrix of size nx1 where n is the number of features.
- * @param gradient - Column matrix of size nx1 where n is the number of features.
+ * @param weightsOld - Column matrix of size dx1 where d is the number of features.
+ * @param gradient - Column matrix of size dx1 where d is the number of features.
* @param stepSize - step size across iterations
* @param iter - Iteration number
* @param regParam - Regularization parameter
@@ -43,23 +52,29 @@ abstract class Updater extends Serializable {
}
/**
- * A simple updater that adaptively adjusts the learning rate the
- * square root of the number of iterations. Does not perform any regularization.
+ * A simple updater for gradient descent *without* any regularization.
+ * Uses a step-size decreasing with the square root of the number of iterations.
*/
class SimpleUpdater extends Updater {
override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
val thisIterStepSize = stepSize / math.sqrt(iter)
- val normGradient = gradient.mul(thisIterStepSize)
- (weightsOld.sub(normGradient), 0)
+ val step = gradient.mul(thisIterStepSize)
+ (weightsOld.sub(step), 0)
}
}
/**
- * Updater that adjusts learning rate and performs L1 regularization.
+ * Updater for L1 regularized problems.
+ * R(w) = ||w||_1
+ * Uses a step-size decreasing with the square root of the number of iterations.
+
+ * Instead of subgradient of the regularizer, the proximal operator for the
+ * L1 regularization is applied after the gradient step. This is known to
+ * result in better sparsity of the intermediate solution.
*
- * The corresponding proximal operator used is the soft-thresholding function.
- * That is, each weight component is shrunk towards 0 by shrinkageVal.
+ * The corresponding proximal operator for the L1 norm is the soft-thresholding
+ * function. That is, each weight component is shrunk towards 0 by shrinkageVal.
*
* If w > shrinkageVal, set weight component to w-shrinkageVal.
* If w < -shrinkageVal, set weight component to w+shrinkageVal.
@@ -71,10 +86,10 @@ class L1Updater extends Updater {
override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
val thisIterStepSize = stepSize / math.sqrt(iter)
- val normGradient = gradient.mul(thisIterStepSize)
+ val step = gradient.mul(thisIterStepSize)
// Take gradient step
- val newWeights = weightsOld.sub(normGradient)
- // Soft thresholding
+ val newWeights = weightsOld.sub(step)
+ // Apply proximal operator (soft thresholding)
val shrinkageVal = regParam * thisIterStepSize
(0 until newWeights.length).foreach { i =>
val wi = newWeights.get(i)
@@ -85,19 +100,19 @@ class L1Updater extends Updater {
}
/**
- * Updater that adjusts the learning rate and performs L2 regularization
- *
- * See, for example, explanation of gradient and loss with L2 regularization on slide 21-22
- * of <a href="http://people.cs.umass.edu/~sheldon/teaching/2012fa/ml/files/lec7-annotated.pdf">
- * these slides</a>.
+ * Updater for L2 regularized problems.
+ * R(w) = 1/2 ||w||^2
+ * Uses a step-size decreasing with the square root of the number of iterations.
*/
class SquaredL2Updater extends Updater {
override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
val thisIterStepSize = stepSize / math.sqrt(iter)
- val normGradient = gradient.mul(thisIterStepSize)
- val newWeights = weightsOld.mul(1.0 - 2.0 * thisIterStepSize * regParam).sub(normGradient)
- (newWeights, pow(newWeights.norm2, 2.0) * regParam)
+ val step = gradient.mul(thisIterStepSize)
+ // add up both updates from the gradient of the loss (= step) as well as
+ // the gradient of the regularizer (= regParam * weightsOld)
+ val newWeights = weightsOld.mul(1.0 - thisIterStepSize * regParam).sub(step)
+ (newWeights, 0.5 * pow(newWeights.norm2, 2.0) * regParam)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index 7c41793722..fb2bc9b92a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -44,6 +44,11 @@ class LassoModel(
/**
* Train a regression model with L1-regularization using Stochastic Gradient Descent.
+ * This solves the l1-regularized least squares regression formulation
+ * f(weights) = 1/n ||A weights-y||^2 + regParam ||weights||_1
+ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
+ * its corresponding right hand side label y.
+ * See also the documentation for the precise formulation.
*/
class LassoWithSGD private (
var stepSize: Double,
@@ -53,7 +58,7 @@ class LassoWithSGD private (
extends GeneralizedLinearAlgorithm[LassoModel]
with Serializable {
- val gradient = new SquaredGradient()
+ val gradient = new LeastSquaresGradient()
val updater = new L1Updater()
@transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
.setNumIterations(numIterations)
@@ -113,12 +118,13 @@ object LassoWithSGD {
/**
* Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
- * gradient descent are initialized using the initial weights provided.
+ * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used
+ * in gradient descent are initialized using the initial weights provided.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param numIterations Number of iterations of gradient descent to run.
- * @param stepSize Step size to be used for each iteration of gradient descent.
+ * @param stepSize Step size scaling to be used for the iterations of gradient descent.
* @param regParam Regularization parameter.
* @param miniBatchFraction Fraction of data to be used per iteration.
* @param initialWeights Initial set of weights to be used. Array should be equal in size to
@@ -140,9 +146,10 @@ object LassoWithSGD {
/**
* Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient.
+ * `miniBatchFraction` fraction of the data to calculate a stochastic gradient.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param regParam Regularization parameter.
@@ -162,9 +169,10 @@ object LassoWithSGD {
/**
* Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. We use the entire data set to
- * update the gradient in each iteration.
+ * update the true gradient in each iteration.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param stepSize Step size to be used for each iteration of Gradient Descent.
* @param regParam Regularization parameter.
* @param numIterations Number of iterations of gradient descent to run.
@@ -183,9 +191,10 @@ object LassoWithSGD {
/**
* Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using a step size of 1.0. We use the entire data set to
- * update the gradient in each iteration.
+ * compute the true gradient in each iteration.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param numIterations Number of iterations of gradient descent to run.
* @return a LassoModel which has the weights and offset from training.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index df599fde76..8ee40addb2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -44,6 +44,12 @@ class LinearRegressionModel(
/**
* Train a linear regression model with no regularization using Stochastic Gradient Descent.
+ * This solves the least squares regression formulation
+ * f(weights) = 1/n ||A weights-y||^2
+ * (which is the mean squared error).
+ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
+ * its corresponding right hand side label y.
+ * See also the documentation for the precise formulation.
*/
class LinearRegressionWithSGD private (
var stepSize: Double,
@@ -52,7 +58,7 @@ class LinearRegressionWithSGD private (
extends GeneralizedLinearAlgorithm[LinearRegressionModel]
with Serializable {
- val gradient = new SquaredGradient()
+ val gradient = new LeastSquaresGradient()
val updater = new SimpleUpdater()
val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
.setNumIterations(numIterations)
@@ -76,10 +82,11 @@ object LinearRegressionWithSGD {
/**
* Train a Linear Regression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
- * gradient descent are initialized using the initial weights provided.
+ * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used
+ * in gradient descent are initialized using the initial weights provided.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param miniBatchFraction Fraction of data to be used per iteration.
@@ -101,9 +108,10 @@ object LinearRegressionWithSGD {
/**
* Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient.
+ * `miniBatchFraction` fraction of the data to calculate a stochastic gradient.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param miniBatchFraction Fraction of data to be used per iteration.
@@ -121,9 +129,10 @@ object LinearRegressionWithSGD {
/**
* Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. We use the entire data set to
- * update the gradient in each iteration.
+ * compute the true gradient in each iteration.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param stepSize Step size to be used for each iteration of Gradient Descent.
* @param numIterations Number of iterations of gradient descent to run.
* @return a LinearRegressionModel which has the weights and offset from training.
@@ -140,9 +149,10 @@ object LinearRegressionWithSGD {
/**
* Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using a step size of 1.0. We use the entire data set to
- * update the gradient in each iteration.
+ * compute the true gradient in each iteration.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param numIterations Number of iterations of gradient descent to run.
* @return a LinearRegressionModel which has the weights and offset from training.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index 0c0e67fb7b..c504d3d40c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -44,6 +44,11 @@ class RidgeRegressionModel(
/**
* Train a regression model with L2-regularization using Stochastic Gradient Descent.
+ * This solves the l1-regularized least squares regression formulation
+ * f(weights) = 1/n ||A weights-y||^2 + regParam/2 ||weights||^2
+ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
+ * its corresponding right hand side label y.
+ * See also the documentation for the precise formulation.
*/
class RidgeRegressionWithSGD private (
var stepSize: Double,
@@ -53,7 +58,7 @@ class RidgeRegressionWithSGD private (
extends GeneralizedLinearAlgorithm[RidgeRegressionModel]
with Serializable {
- val gradient = new SquaredGradient()
+ val gradient = new LeastSquaresGradient()
val updater = new SquaredL2Updater()
@transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
@@ -114,8 +119,8 @@ object RidgeRegressionWithSGD {
/**
* Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
- * gradient descent are initialized using the initial weights provided.
+ * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used
+ * in gradient descent are initialized using the initial weights provided.
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
@@ -141,7 +146,7 @@ object RidgeRegressionWithSGD {
/**
* Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient.
+ * `miniBatchFraction` fraction of the data to calculate a stochastic gradient.
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
@@ -163,7 +168,7 @@ object RidgeRegressionWithSGD {
/**
* Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. We use the entire data set to
- * update the gradient in each iteration.
+ * compute the true gradient in each iteration.
*
* @param input RDD of (label, array of features) pairs.
* @param stepSize Step size to be used for each iteration of Gradient Descent.
@@ -184,7 +189,7 @@ object RidgeRegressionWithSGD {
/**
* Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using a step size of 1.0. We use the entire data set to
- * update the gradient in each iteration.
+ * compute the true gradient in each iteration.
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.