Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala              | 115
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala |   2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala              |  21
3 files changed, 101 insertions, 37 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index f92c6816eb..cc9ad22cb8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.StatCounter
+import org.apache.spark.Logging
/**
* Params for linear regression.
@@ -48,7 +49,7 @@ private[regression] trait LinearRegressionParams extends RegressorParams
*/
@AlphaComponent
class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
- with LinearRegressionParams {
+ with LinearRegressionParams with Logging {
/**
* Set the regularization parameter.
@@ -110,6 +111,15 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
val yMean = statCounter.mean
val yStd = math.sqrt(statCounter.variance)
+ // If the yStd is zero, then the intercept is yMean with zero weights;
+ // as a result, training is not needed.
+ if (yStd == 0.0) {
+ logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
+ s"and the intercept will be the mean of the label; as a result, training is not needed.")
+ if (handlePersistence) instances.unpersist()
+ return new LinearRegressionModel(this, paramMap, Vectors.sparse(numFeatures, Seq()), yMean)
+ }
+
val featuresMean = summarizer.mean.toArray
val featuresStd = summarizer.variance.toArray.map(math.sqrt)
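The early return added above covers the degenerate case where the label has zero standard deviation, i.e. every label equals its mean, so a model with all-zero weights and intercept yMean already fits exactly. A minimal standalone sketch of that returned model, using made-up values rather than anything from this patch:

import org.apache.spark.mllib.linalg.Vectors

// Hypothetical constant-label case: every label is 3.0, so yStd == 0.0.
val numFeatures = 4
val yMean = 3.0

// What the early exit returns: all-zero weights stored sparsely (no active entries)
// and the label mean as the intercept.
val weights = Vectors.sparse(numFeatures, Seq())
val intercept = yMean

// Prediction is weights . x + intercept, which reduces to yMean for any x.
val x = Vectors.dense(1.0, 2.0, 3.0, 4.0)
val prediction = weights.toArray.zip(x.toArray).map { case (w, v) => w * v }.sum + intercept
assert(prediction == yMean)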
@@ -141,7 +151,6 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
}
lossHistory += state.value
- // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector.
// The weights are trained in the scaled space; we're converting them back to
// the original space.
val weights = {
@@ -158,11 +167,10 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
// converged. See the following discussion for detail.
// http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
+ if (handlePersistence) instances.unpersist()
- if (handlePersistence) {
- instances.unpersist()
- }
- new LinearRegressionModel(this, paramMap, weights, intercept)
+ // TODO: Convert to a sparse format based on storage size; this could instead be based on scoring speed.
+ new LinearRegressionModel(this, paramMap, weights.compressed, intercept)
}
}
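For context on the hunk above: the coefficients are trained against standardized features and label, so they are mapped back to the original scale before the closed-form intercept is computed, and the new weights.compressed call simply keeps whichever of the dense or sparse representations is smaller. A rough standalone sketch with made-up statistics; the yStd / featuresStd scaling factor is an assumption about the surrounding code, which is not shown in this hunk:

import org.apache.spark.mllib.linalg.Vectors

// Made-up statistics standing in for the summarizer/statCounter output.
val featuresMean = Array(1.0, -2.0)
val featuresStd  = Array(2.0, 4.0)
val yMean = 5.0
val yStd  = 3.0

// Weights as produced by the optimizer in the scaled (unit-variance) space.
val scaledWeights = Array(0.8, 0.0)

// Map back to the original scale: w_i = scaledW_i * yStd / featuresStd_i.
val weights = scaledWeights.zip(featuresStd).map {
  case (w, s) => if (s != 0.0) w * yStd / s else 0.0
}

// Closed-form intercept, as in the hunk: yMean - weights . featuresMean.
val intercept = yMean - weights.zip(featuresMean).map { case (w, m) => w * m }.sum

// `compressed` keeps the smaller of the dense and sparse representations.
val compactWeights = Vectors.dense(weights).compressed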
@@ -198,15 +206,84 @@ class LinearRegressionModel private[ml] (
* Two LeastSquaresAggregators can be merged together to produce a summary of the loss and gradient of
* the corresponding joint dataset.
*
-
- * * Compute gradient and loss for a Least-squared loss function, as used in linear regression.
- * This is correct for the averaged least squares loss function (mean squared error)
- * L = 1/2n ||A weights-y||^2
- * See also the documentation for the precise formulation.
+ * To improve the convergence rate during the optimization process, and to prevent features
+ * with very large variances from exerting an overly large influence during model training,
+ * packages like R's GLMNET standardize the features to unit variance and remove the mean to
+ * reduce the condition number. The model is then trained in the scaled space, but the weights
+ * are returned in the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
+ *
+ * However, we don't want to apply `StandardScaler` to the training dataset and then cache
+ * the standardized dataset, since that would add a lot of overhead. Instead, we perform the
+ * scaling implicitly when we compute the objective function. The following is the
+ * mathematical derivation.
+ *
+ * Note that we don't handle the intercept by appending a bias column here, because the
+ * intercept can be computed in closed form once the coefficients have converged.
+ * See the following discussion for details.
+ * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
+ *
+ * The objective function in the scaled space is given by
+ * {{{
+ * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2,
+ * }}}
+ * where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i,
+ * \bar{y} is the mean of label, and \hat{y} is the standard deviation of label.
+ *
+ * This can be rewritten as
+ * {{{
+ * L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y}
+ * + \bar{y} / \hat{y}||^2
+ * = 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2
+ * }}}
+ * where w_i^\prime is the effective weight, defined as w_i/\hat{x_i}, and offset is
+ * {{{
+ * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}.
+ * }}}, and diff is
+ * {{{
+ * \sum_i w_i^\prime x_i - y / \hat{y} + offset
+ * }}}
*
- * @param weights weights/coefficients corresponding to features
+ * Note that the effective weights and the offset don't depend on the training dataset,
+ * so they can be precomputed.
*
- * @param updater Updater to be used to update weights after every iteration.
+ * Now, the first derivative of the objective function in the scaled space is
+ * {{{
+ * \frac{\partial L}{\partial w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i}
+ * }}}
+ * However, (x_i - \bar{x_i}) will densify the computation, so this is not
+ * an ideal formula when the training dataset is stored in a sparse format.
+ *
+ * This can be addressed by keeping the sum of diff and adding the dense
+ * \bar{x_i} / \hat{x_i} terms at the end. The first derivative of the total
+ * objective function over all the samples is
+ * {{{
+ * \frac{\partial L}{\partial w_i} =
+ *     1/N \sum_j diff_j (x_{ij} - \bar{x_i}) / \hat{x_i}
+ *   = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i} / \hat{x_i})
+ *   = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + correction_i)
+ * }}}
+ * where correction_i = - diffSum \bar{x_i} / \hat{x_i}.
+ *
+ * A simple calculation shows that diffSum is actually zero, so we don't even
+ * need to add the correction terms at the end. From the definition of diff,
+ * {{{
+ * diffSum = \sum_j (\sum_i w_i(x_{ij} - \bar{x_i}) / \hat{x_i} - (y_j - \bar{y}) / \hat{y})
+ *         = N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y} - \bar{y}) / \hat{y})
+ *         = 0
+ * }}}
+ *
+ * As a result, the first derivative of the total objective function reduces to a simple sum
+ * over the training data, which can easily be computed in a distributed fashion and is
+ * friendly to sparse formats:
+ * {{{
+ * \frac{\partial L}{\partial w_i} = 1/N (\sum_j diff_j x_{ij} / \hat{x_i})
+ * }}}
+ *
+ * @param weights The weights/coefficients corresponding to the features.
+ * @param labelStd The standard deviation value of the label.
+ * @param labelMean The mean value of the label.
+ * @param featuresStd The standard deviation values of the features.
+ * @param featuresMean The mean values of the features.
*/
private class LeastSquaresAggregator(
weights: Vector,
@@ -302,18 +379,6 @@ private class LeastSquaresAggregator(
def gradient: Vector = {
val result = Vectors.dense(gradientSumArray.clone())
-
- val correction = {
- val temp = effectiveWeightsArray.clone()
- var i = 0
- while (i < temp.length) {
- temp(i) *= featuresMean(i)
- i += 1
- }
- Vectors.dense(temp)
- }
-
- axpy(-diffSum, correction, result)
scal(1.0 / totalCnt, result)
result
}
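The correction term removed from gradient above is justified by the diffSum-is-zero argument in the new doc comment. A small self-contained numeric check of that claim in plain Scala, using arbitrary made-up data rather than anything from this patch:

// Three samples with two features, and arbitrary weights in the scaled space.
val x = Array(Array(1.0, 10.0), Array(2.0, 30.0), Array(3.0, 20.0))
val y = Array(1.0, 4.0, 7.0)
val w = Array(0.3, -0.7)

def mean(a: Array[Double]): Double = a.sum / a.length
def stddev(a: Array[Double]): Double = {
  val m = mean(a)
  math.sqrt(a.map(v => (v - m) * (v - m)).sum / a.length)
}

val xMean = Array.tabulate(2)(i => mean(x.map(_(i))))
val xStd  = Array.tabulate(2)(i => stddev(x.map(_(i))))
val yMean = mean(y)
val yStd  = stddev(y)

// diff_j = \sum_i w_i (x_{ij} - \bar{x_i}) / \hat{x_i} - (y_j - \bar{y}) / \hat{y}
val diffs = x.zip(y).map { case (xi, yi) =>
  xi.indices.map(i => w(i) * (xi(i) - xMean(i)) / xStd(i)).sum - (yi - yMean) / yStd
}

// The residuals sum to zero (up to floating point), so the -diffSum * correction
// update that used to live in `gradient` contributes nothing.
assert(math.abs(diffs.sum) < 1e-12)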
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 9fd60ff7a0..26be30ff9d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -225,7 +225,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
throw new SparkException("Input validation failed.")
}
- /*
+ /**
* Scaling columns to unit variance as a heuristic to reduce the condition number:
*
* During the optimization process, the convergence (rate) depends on the condition number of
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
index d7bb943e84..91c2f4ccd3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
@@ -103,17 +103,16 @@ object LinearDataGenerator {
val rnd = new Random(seed)
val x = Array.fill[Array[Double]](nPoints)(
- Array.fill[Double](weights.length)(rnd.nextDouble))
-
- x.map(vector => {
- // This doesn't work if `vector` is a sparse vector.
- val vectorArray = vector.toArray
- var i = 0
- while (i < vectorArray.size) {
- vectorArray(i) = (vectorArray(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
- i += 1
- }
- })
+ Array.fill[Double](weights.length)(rnd.nextDouble()))
+
+ x.foreach {
+ case v =>
+ var i = 0
+ while (i < v.length) {
+ v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
+ i += 1
+ }
+ }
val y = x.map { xi =>
blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian()
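The replacement loop above turns uniform draws on [0, 1) into samples with the requested per-feature mean and variance: a uniform variable on [0, 1) has mean 0.5 and variance 1/12, so subtracting 0.5 and multiplying by math.sqrt(12.0 * xVariance(i)) gives variance xVariance(i), and adding xMean(i) sets the mean. A standalone sanity check of those moments, not part of the patch:

import scala.util.Random

// Apply the same shift-and-scale transform to many uniform samples and compare
// the empirical moments with the requested targets.
val rnd = new Random(42)
val targetMean = 3.0
val targetVariance = 4.0

val samples = Array.fill(1000000)(
  (rnd.nextDouble() - 0.5) * math.sqrt(12.0 * targetVariance) + targetMean)

val sampleMean = samples.sum / samples.length
val sampleVariance = samples.map(v => (v - sampleMean) * (v - sampleMean)).sum / samples.length

// Both should land close to 3.0 and 4.0 up to sampling noise.
println(s"mean = $sampleMean, variance = $sampleVariance")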