aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala41
3 files changed, 50 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 8e87b98bac..b967b22e81 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -149,7 +149,13 @@ object GradientDescent extends Logging {
// Initialize weights as a column vector
var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*)
- var regVal = 0.0
+
+ /**
+ * For the first iteration, the regVal will be initialized as the sum of the
+ * squares of the weights if it's an L2 update; for an L1 update, the same
+ * logic is followed.
+ */
+ var regVal = updater.compute(
+ weights, new DoubleMatrix(initialWeights.length, 1), 0, 1, regParam)._2
for (i <- 1 to numIterations) {
// Sample a subset (fraction miniBatchFraction) of the total data
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
index 889a03e3e6..bf8f731459 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
@@ -111,6 +111,8 @@ class SquaredL2Updater extends Updater {
val step = gradient.mul(thisIterStepSize)
// add up both updates from the gradient of the loss (= step) as well as
// the gradient of the regularizer (= regParam * weightsOld)
+ // w' = w - thisIterStepSize * (gradient + regParam * w)
+ // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
val newWeights = weightsOld.mul(1.0 - thisIterStepSize * regParam).sub(step)
(newWeights, 0.5 * pow(newWeights.norm2, 2.0) * regParam)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index a453de6767..631d0e2ad9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -104,4 +104,45 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa
val lossDiff = loss.init.zip(loss.tail).map { case (lhs, rhs) => lhs - rhs }
assert(lossDiff.count(_ > 0).toDouble / lossDiff.size > 0.8)
}
+
+ test("Test the loss and gradient of first iteration with regularization.") {
+
+ val gradient = new LogisticGradient()
+ val updater = new SquaredL2Updater()
+
+ // Add an extra variable consisting of all 1.0's for the intercept.
+ val testData = GradientDescentSuite.generateGDInput(2.0, -1.5, 10000, 42)
+ val data = testData.map { case LabeledPoint(label, features) =>
+ label -> Array(1.0, features: _*)
+ }
+
+ val dataRDD = sc.parallelize(data, 2).cache()
+
+ // Prepare non-zero weights
+ val initialWeightsWithIntercept = Array(1.0, 0.5)
+
+ val regParam0 = 0
+ val (newWeights0, loss0) = GradientDescent.runMiniBatchSGD(
+ dataRDD, gradient, updater, 1, 1, regParam0, 1.0, initialWeightsWithIntercept)
+
+ val regParam1 = 1
+ val (newWeights1, loss1) = GradientDescent.runMiniBatchSGD(
+ dataRDD, gradient, updater, 1, 1, regParam1, 1.0, initialWeightsWithIntercept)
+
+ def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = {
+ math.abs(x - y) / (math.abs(y) + 1e-15) < tol
+ }
+
+ assert(compareDouble(
+ loss1(0),
+ loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) +
+ math.pow(initialWeightsWithIntercept(1), 2)) / 2),
+ """For non-zero weights, the regVal should be \frac{1}{2}\sum_i w_i^2.""")
+
+ assert(
+ compareDouble(newWeights1(0) , newWeights0(0) - initialWeightsWithIntercept(0)) &&
+ compareDouble(newWeights1(1) , newWeights0(1) - initialWeightsWithIntercept(1)),
+ "The difference between newWeights with/without regularization " +
+ "should be initialWeightsWithIntercept.")
+ }
}