aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorMartin Jaggi <m.jaggi@gmail.com>2014-02-09 15:19:50 -0800
committerReynold Xin <rxin@apache.org>2014-02-09 15:19:50 -0800
commit2182aa3c55737a90e0ff200eede7146b440801a3 (patch)
tree6696bfc522a68e9858679130b9758144132fb356 /mllib
parentafc8f3cb9a7afe3249500a7d135b4a54bb3e58c4 (diff)
downloadspark-2182aa3c55737a90e0ff200eede7146b440801a3.tar.gz
spark-2182aa3c55737a90e0ff200eede7146b440801a3.tar.bz2
spark-2182aa3c55737a90e0ff200eede7146b440801a3.zip
Merge pull request #566 from martinjaggi/copy-MLlib-d.
new MLlib documentation for optimization, regression and classification new documentation with tex formulas, hopefully improving usability and reproducibility of the offered MLlib methods. also did some minor changes in the code for consistency. scala tests pass. this is the rebased branch, i deleted the old PR jira: https://spark-project.atlassian.net/browse/MLLIB-19 Author: Martin Jaggi <m.jaggi@gmail.com> Closes #566 and squashes the following commits: 5f0f31e [Martin Jaggi] line wrap at 100 chars 4e094fb [Martin Jaggi] better description of GradientDescent 1d6965d [Martin Jaggi] remove broken url ea569c3 [Martin Jaggi] telling what updater actually does 964732b [Martin Jaggi] lambda R() in documentation a6c6228 [Martin Jaggi] better comments in SGD code for regression b32224a [Martin Jaggi] new optimization documentation d5dfef7 [Martin Jaggi] new classification and regression documentation b07ead6 [Martin Jaggi] correct scaling for MSE loss ba6158c [Martin Jaggi] use d for the number of features bab2ed2 [Martin Jaggi] renaming LeastSquaresGradient
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala23
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala42
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala61
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala31
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala30
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala17
6 files changed, 130 insertions, 74 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index c590492e7a..82124703da 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -24,10 +24,10 @@ import org.jblas.DoubleMatrix
*/
abstract class Gradient extends Serializable {
/**
- * Compute the gradient and loss given features of a single data point.
+ * Compute the gradient and loss given the features of a single data point.
*
- * @param data - Feature values for one data point. Column matrix of size nx1
- * where n is the number of features.
+ * @param data - Feature values for one data point. Column matrix of size dx1
+ * where d is the number of features.
* @param label - Label for this data item.
* @param weights - Column matrix containing weights for every feature.
*
@@ -40,7 +40,8 @@ abstract class Gradient extends Serializable {
}
/**
- * Compute gradient and loss for a logistic loss function.
+ * Compute gradient and loss for a logistic loss function, as used in binary classification.
+ * See also the documentation for the precise formulation.
*/
class LogisticGradient extends Gradient {
override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
@@ -61,22 +62,26 @@ class LogisticGradient extends Gradient {
}
/**
- * Compute gradient and loss for a Least-squared loss function.
+ * Compute gradient and loss for a Least-squared loss function, as used in linear regression.
+ * This is correct for the averaged least squares loss function (mean squared error)
+ * L = 1/n ||A weights-y||^2
+ * See also the documentation for the precise formulation.
*/
-class SquaredGradient extends Gradient {
+class LeastSquaresGradient extends Gradient {
override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
(DoubleMatrix, Double) = {
val diff: Double = data.dot(weights) - label
- val loss = 0.5 * diff * diff
- val gradient = data.mul(diff)
+ val loss = diff * diff
+ val gradient = data.mul(2.0 * diff)
(gradient, loss)
}
}
/**
- * Compute gradient and loss for a Hinge loss function.
+ * Compute gradient and loss for a Hinge loss function, as used in SVM binary classification.
+ * See also the documentation for the precise formulation.
* NOTE: This assumes that the labels are {0,1}
*/
class HingeGradient extends Gradient {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index cd80134737..8e87b98bac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -17,9 +17,8 @@
package org.apache.spark.mllib.optimization
-import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
import org.jblas.DoubleMatrix
@@ -39,7 +38,8 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
private var miniBatchFraction: Double = 1.0
/**
- * Set the step size per-iteration of SGD. Default 1.0.
+ * Set the initial step size of SGD for the first step. Default 1.0.
+ * In subsequent steps, the step size will decrease with stepSize/sqrt(t)
*/
def setStepSize(step: Double): this.type = {
this.stepSize = step
@@ -47,7 +47,8 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
}
/**
- * Set fraction of data to be used for each SGD iteration. Default 1.0.
+ * Set fraction of data to be used for each SGD iteration.
+ * Default 1.0 (corresponding to deterministic/classical gradient descent)
*/
def setMiniBatchFraction(fraction: Double): this.type = {
this.miniBatchFraction = fraction
@@ -63,7 +64,7 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
}
/**
- * Set the regularization parameter used for SGD. Default 0.0.
+ * Set the regularization parameter. Default 0.0.
*/
def setRegParam(regParam: Double): this.type = {
this.regParam = regParam
@@ -71,7 +72,8 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
}
/**
- * Set the gradient function to be used for SGD.
+ * Set the gradient function (of the loss function of one single data example)
+ * to be used for SGD.
*/
def setGradient(gradient: Gradient): this.type = {
this.gradient = gradient
@@ -80,7 +82,9 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
/**
- * Set the updater function to be used for SGD.
+ * Set the updater function to actually perform a gradient step in a given direction.
+ * The updater is responsible to perform the update from the regularization term as well,
+ * and therefore determines what kind or regularization is used, if any.
*/
def setUpdater(updater: Updater): this.type = {
this.updater = updater
@@ -107,20 +111,26 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
// Top-level method to run gradient descent.
object GradientDescent extends Logging {
/**
- * Run gradient descent in parallel using mini batches.
+ * Run stochastic gradient descent (SGD) in parallel using mini batches.
+ * In each iteration, we sample a subset (fraction miniBatchFraction) of the total data
+ * in order to compute a gradient estimate.
+ * Sampling, and averaging the subgradients over this subset is performed using one standard
+ * spark map-reduce in each iteration.
*
- * @param data - Input data for SGD. RDD of form (label, [feature values]).
- * @param gradient - Gradient object that will be used to compute the gradient.
- * @param updater - Updater object that will be used to update the model.
- * @param stepSize - stepSize to be used during update.
+ * @param data - Input data for SGD. RDD of the set of data examples, each of
+ * the form (label, [feature values]).
+ * @param gradient - Gradient object (used to compute the gradient of the loss function of
+ * one single data example)
+ * @param updater - Updater function to actually perform a gradient step in a given direction.
+ * @param stepSize - initial step size for the first step
* @param numIterations - number of iterations that SGD should be run.
* @param regParam - regularization parameter
* @param miniBatchFraction - fraction of the input data set that should be used for
* one iteration of SGD. Default value 1.0.
*
* @return A tuple containing two elements. The first element is a column matrix containing
- * weights for every feature, and the second element is an array containing the stochastic
- * loss computed for every iteration.
+ * weights for every feature, and the second element is an array containing the
+ * stochastic loss computed for every iteration.
*/
def runMiniBatchSGD(
data: RDD[(Double, Array[Double])],
@@ -142,6 +152,8 @@ object GradientDescent extends Logging {
var regVal = 0.0
for (i <- 1 to numIterations) {
+ // Sample a subset (fraction miniBatchFraction) of the total data
+ // compute and sum up the subgradients on this subset (this is one map-reduce)
val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i).map {
case (y, features) =>
val featuresCol = new DoubleMatrix(features.length, 1, features:_*)
@@ -160,7 +172,7 @@ object GradientDescent extends Logging {
regVal = update._2
}
- logInfo("GradientDescent finished. Last 10 stochastic losses %s".format(
+ logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
stochasticLossHistory.takeRight(10).mkString(", ")))
(weights.toArray, stochasticLossHistory.toArray)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
index 37124f261e..889a03e3e6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
@@ -21,16 +21,25 @@ import scala.math._
import org.jblas.DoubleMatrix
/**
- * Class used to update weights used in Gradient Descent.
+ * Class used to perform steps (weight update) using Gradient Descent methods.
+ *
+ * For general minimization problems, or for regularized problems of the form
+ * min L(w) + regParam * R(w),
+ * the compute function performs the actual update step, when given some
+ * (e.g. stochastic) gradient direction for the loss L(w),
+ * and a desired step-size (learning rate).
+ *
+ * The updater is responsible to also perform the update coming from the
+ * regularization term R(w) (if any regularization is used).
*/
abstract class Updater extends Serializable {
/**
* Compute an updated value for weights given the gradient, stepSize, iteration number and
- * regularization parameter. Also returns the regularization value computed using the
- * *updated* weights.
+ * regularization parameter. Also returns the regularization value regParam * R(w)
+ * computed using the *updated* weights.
*
- * @param weightsOld - Column matrix of size nx1 where n is the number of features.
- * @param gradient - Column matrix of size nx1 where n is the number of features.
+ * @param weightsOld - Column matrix of size dx1 where d is the number of features.
+ * @param gradient - Column matrix of size dx1 where d is the number of features.
* @param stepSize - step size across iterations
* @param iter - Iteration number
* @param regParam - Regularization parameter
@@ -43,23 +52,29 @@ abstract class Updater extends Serializable {
}
/**
- * A simple updater that adaptively adjusts the learning rate the
- * square root of the number of iterations. Does not perform any regularization.
+ * A simple updater for gradient descent *without* any regularization.
+ * Uses a step-size decreasing with the square root of the number of iterations.
*/
class SimpleUpdater extends Updater {
override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
val thisIterStepSize = stepSize / math.sqrt(iter)
- val normGradient = gradient.mul(thisIterStepSize)
- (weightsOld.sub(normGradient), 0)
+ val step = gradient.mul(thisIterStepSize)
+ (weightsOld.sub(step), 0)
}
}
/**
- * Updater that adjusts learning rate and performs L1 regularization.
+ * Updater for L1 regularized problems.
+ * R(w) = ||w||_1
+ * Uses a step-size decreasing with the square root of the number of iterations.
+
+ * Instead of subgradient of the regularizer, the proximal operator for the
+ * L1 regularization is applied after the gradient step. This is known to
+ * result in better sparsity of the intermediate solution.
*
- * The corresponding proximal operator used is the soft-thresholding function.
- * That is, each weight component is shrunk towards 0 by shrinkageVal.
+ * The corresponding proximal operator for the L1 norm is the soft-thresholding
+ * function. That is, each weight component is shrunk towards 0 by shrinkageVal.
*
* If w > shrinkageVal, set weight component to w-shrinkageVal.
* If w < -shrinkageVal, set weight component to w+shrinkageVal.
@@ -71,10 +86,10 @@ class L1Updater extends Updater {
override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
val thisIterStepSize = stepSize / math.sqrt(iter)
- val normGradient = gradient.mul(thisIterStepSize)
+ val step = gradient.mul(thisIterStepSize)
// Take gradient step
- val newWeights = weightsOld.sub(normGradient)
- // Soft thresholding
+ val newWeights = weightsOld.sub(step)
+ // Apply proximal operator (soft thresholding)
val shrinkageVal = regParam * thisIterStepSize
(0 until newWeights.length).foreach { i =>
val wi = newWeights.get(i)
@@ -85,19 +100,19 @@ class L1Updater extends Updater {
}
/**
- * Updater that adjusts the learning rate and performs L2 regularization
- *
- * See, for example, explanation of gradient and loss with L2 regularization on slide 21-22
- * of <a href="http://people.cs.umass.edu/~sheldon/teaching/2012fa/ml/files/lec7-annotated.pdf">
- * these slides</a>.
+ * Updater for L2 regularized problems.
+ * R(w) = 1/2 ||w||^2
+ * Uses a step-size decreasing with the square root of the number of iterations.
*/
class SquaredL2Updater extends Updater {
override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
val thisIterStepSize = stepSize / math.sqrt(iter)
- val normGradient = gradient.mul(thisIterStepSize)
- val newWeights = weightsOld.mul(1.0 - 2.0 * thisIterStepSize * regParam).sub(normGradient)
- (newWeights, pow(newWeights.norm2, 2.0) * regParam)
+ val step = gradient.mul(thisIterStepSize)
+ // add up both updates from the gradient of the loss (= step) as well as
+ // the gradient of the regularizer (= regParam * weightsOld)
+ val newWeights = weightsOld.mul(1.0 - thisIterStepSize * regParam).sub(step)
+ (newWeights, 0.5 * pow(newWeights.norm2, 2.0) * regParam)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index 7c41793722..fb2bc9b92a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -44,6 +44,11 @@ class LassoModel(
/**
* Train a regression model with L1-regularization using Stochastic Gradient Descent.
+ * This solves the l1-regularized least squares regression formulation
+ * f(weights) = 1/n ||A weights-y||^2 + regParam ||weights||_1
+ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
+ * its corresponding right hand side label y.
+ * See also the documentation for the precise formulation.
*/
class LassoWithSGD private (
var stepSize: Double,
@@ -53,7 +58,7 @@ class LassoWithSGD private (
extends GeneralizedLinearAlgorithm[LassoModel]
with Serializable {
- val gradient = new SquaredGradient()
+ val gradient = new LeastSquaresGradient()
val updater = new L1Updater()
@transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
.setNumIterations(numIterations)
@@ -113,12 +118,13 @@ object LassoWithSGD {
/**
* Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
- * gradient descent are initialized using the initial weights provided.
+ * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used
+ * in gradient descent are initialized using the initial weights provided.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param numIterations Number of iterations of gradient descent to run.
- * @param stepSize Step size to be used for each iteration of gradient descent.
+ * @param stepSize Step size scaling to be used for the iterations of gradient descent.
* @param regParam Regularization parameter.
* @param miniBatchFraction Fraction of data to be used per iteration.
* @param initialWeights Initial set of weights to be used. Array should be equal in size to
@@ -140,9 +146,10 @@ object LassoWithSGD {
/**
* Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient.
+ * `miniBatchFraction` fraction of the data to calculate a stochastic gradient.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param regParam Regularization parameter.
@@ -162,9 +169,10 @@ object LassoWithSGD {
/**
* Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. We use the entire data set to
- * update the gradient in each iteration.
+ * update the true gradient in each iteration.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param stepSize Step size to be used for each iteration of Gradient Descent.
* @param regParam Regularization parameter.
* @param numIterations Number of iterations of gradient descent to run.
@@ -183,9 +191,10 @@ object LassoWithSGD {
/**
* Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using a step size of 1.0. We use the entire data set to
- * update the gradient in each iteration.
+ * compute the true gradient in each iteration.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param numIterations Number of iterations of gradient descent to run.
* @return a LassoModel which has the weights and offset from training.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index df599fde76..8ee40addb2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -44,6 +44,12 @@ class LinearRegressionModel(
/**
* Train a linear regression model with no regularization using Stochastic Gradient Descent.
+ * This solves the least squares regression formulation
+ * f(weights) = 1/n ||A weights-y||^2
+ * (which is the mean squared error).
+ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
+ * its corresponding right hand side label y.
+ * See also the documentation for the precise formulation.
*/
class LinearRegressionWithSGD private (
var stepSize: Double,
@@ -52,7 +58,7 @@ class LinearRegressionWithSGD private (
extends GeneralizedLinearAlgorithm[LinearRegressionModel]
with Serializable {
- val gradient = new SquaredGradient()
+ val gradient = new LeastSquaresGradient()
val updater = new SimpleUpdater()
val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
.setNumIterations(numIterations)
@@ -76,10 +82,11 @@ object LinearRegressionWithSGD {
/**
* Train a Linear Regression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
- * gradient descent are initialized using the initial weights provided.
+ * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used
+ * in gradient descent are initialized using the initial weights provided.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param miniBatchFraction Fraction of data to be used per iteration.
@@ -101,9 +108,10 @@ object LinearRegressionWithSGD {
/**
* Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient.
+ * `miniBatchFraction` fraction of the data to calculate a stochastic gradient.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param miniBatchFraction Fraction of data to be used per iteration.
@@ -121,9 +129,10 @@ object LinearRegressionWithSGD {
/**
* Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. We use the entire data set to
- * update the gradient in each iteration.
+ * compute the true gradient in each iteration.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param stepSize Step size to be used for each iteration of Gradient Descent.
* @param numIterations Number of iterations of gradient descent to run.
* @return a LinearRegressionModel which has the weights and offset from training.
@@ -140,9 +149,10 @@ object LinearRegressionWithSGD {
/**
* Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using a step size of 1.0. We use the entire data set to
- * update the gradient in each iteration.
+ * compute the true gradient in each iteration.
*
- * @param input RDD of (label, array of features) pairs.
+ * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
+ * matrix A as well as the corresponding right hand side label y
* @param numIterations Number of iterations of gradient descent to run.
* @return a LinearRegressionModel which has the weights and offset from training.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index 0c0e67fb7b..c504d3d40c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -44,6 +44,11 @@ class RidgeRegressionModel(
/**
* Train a regression model with L2-regularization using Stochastic Gradient Descent.
+ * This solves the l1-regularized least squares regression formulation
+ * f(weights) = 1/n ||A weights-y||^2 + regParam/2 ||weights||^2
+ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
+ * its corresponding right hand side label y.
+ * See also the documentation for the precise formulation.
*/
class RidgeRegressionWithSGD private (
var stepSize: Double,
@@ -53,7 +58,7 @@ class RidgeRegressionWithSGD private (
extends GeneralizedLinearAlgorithm[RidgeRegressionModel]
with Serializable {
- val gradient = new SquaredGradient()
+ val gradient = new LeastSquaresGradient()
val updater = new SquaredL2Updater()
@transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
@@ -114,8 +119,8 @@ object RidgeRegressionWithSGD {
/**
* Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
- * gradient descent are initialized using the initial weights provided.
+ * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used
+ * in gradient descent are initialized using the initial weights provided.
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
@@ -141,7 +146,7 @@ object RidgeRegressionWithSGD {
/**
* Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient.
+ * `miniBatchFraction` fraction of the data to calculate a stochastic gradient.
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
@@ -163,7 +168,7 @@ object RidgeRegressionWithSGD {
/**
* Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. We use the entire data set to
- * update the gradient in each iteration.
+ * compute the true gradient in each iteration.
*
* @param input RDD of (label, array of features) pairs.
* @param stepSize Step size to be used for each iteration of Gradient Descent.
@@ -184,7 +189,7 @@ object RidgeRegressionWithSGD {
/**
* Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using a step size of 1.0. We use the entire data set to
- * update the gradient in each iteration.
+ * compute the true gradient in each iteration.
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.