diff options
author | Ameet Talwalkar <atalwalkar@gmail.com> | 2013-07-29 22:21:50 -0700 |
---|---|---|
committer | Ameet Talwalkar <atalwalkar@gmail.com> | 2013-07-29 22:21:50 -0700 |
commit | e4387ddf5d1a46dfedece73feff4de6a30f9a220 (patch) | |
tree | 57659c36ae0397d1408d264fdb5d3a3a5e36cc53 | |
parent | 468a36c00526872396196458fd7875fd06ac7108 (diff) | |
download | spark-e4387ddf5d1a46dfedece73feff4de6a30f9a220.tar.gz spark-e4387ddf5d1a46dfedece73feff4de6a30f9a220.tar.bz2 spark-e4387ddf5d1a46dfedece73feff4de6a30f9a220.zip |
made SimpleUpdater consistent with other updaters
-rw-r--r-- | mllib/src/main/scala/spark/mllib/optimization/Updater.scala | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala index e916a92c33..bf506d2f24 100644 --- a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala @@ -41,7 +41,8 @@ abstract class Updater extends Serializable { class SimpleUpdater extends Updater { override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { - val normGradient = gradient.mul(stepSize / math.sqrt(iter)) + val thisIterStepSize = stepSize / math.sqrt(iter) + val normGradient = gradient.mul(thisIterStepSize) (weightsOld.sub(normGradient), 0) } } |