aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorSean Owen <sowen@cloudera.com>2014-01-18 12:53:01 +0000
committerSean Owen <sowen@cloudera.com>2014-01-18 12:53:01 +0000
commite91ad3f164b64e727f41ced6ae20d70ca4c92521 (patch)
tree717f4906a9cca74d87e14587eeb218895d49b986 /mllib/src
parentd749d472b37448edb322bc7208a3db925c9a4fc2 (diff)
downloadspark-e91ad3f164b64e727f41ced6ae20d70ca4c92521.tar.gz
spark-e91ad3f164b64e727f41ced6ae20d70ca4c92521.tar.bz2
spark-e91ad3f164b64e727f41ced6ae20d70ca4c92521.zip
Correct L2 regularized weight update with canonical form
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala6
1 files changed, 5 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
index 4c51f4f881..37124f261e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
@@ -86,13 +86,17 @@ class L1Updater extends Updater {
/**
* Updater that adjusts the learning rate and performs L2 regularization
+ *
+ * See, for example, explanation of gradient and loss with L2 regularization on slide 21-22
+ * of <a href="http://people.cs.umass.edu/~sheldon/teaching/2012fa/ml/files/lec7-annotated.pdf">
+ * these slides</a>.
*/
class SquaredL2Updater extends Updater {
override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
val thisIterStepSize = stepSize / math.sqrt(iter)
val normGradient = gradient.mul(thisIterStepSize)
- val newWeights = weightsOld.sub(normGradient).div(2.0 * thisIterStepSize * regParam + 1.0)
+ val newWeights = weightsOld.mul(1.0 - 2.0 * thisIterStepSize * regParam).sub(normGradient)
(newWeights, pow(newWeights.norm2, 2.0) * regParam)
}
}