aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorShivaram Venkataraman <shivaram@eecs.berkeley.edu>2013-07-29 18:37:28 -0700
committerShivaram Venkataraman <shivaram@eecs.berkeley.edu>2013-07-29 18:37:28 -0700
commit3ca9faa341dcddb54f8b2e26b582c08901ea875f (patch)
tree22dba9691f1cc22c7e602fa285f9ff6b1389b86e /mllib
parent07da72b45190f7db9daa2c6bd33577d28e19e659 (diff)
downloadspark-3ca9faa341dcddb54f8b2e26b582c08901ea875f.tar.gz
spark-3ca9faa341dcddb54f8b2e26b582c08901ea875f.tar.bz2
spark-3ca9faa341dcddb54f8b2e26b582c08901ea875f.zip
Clarify how regVal is computed in Updater docs
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/spark/mllib/optimization/Updater.scala17
1 files changed, 9 insertions, 8 deletions
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
index bbf21e5c28..e916a92c33 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
@@ -23,6 +23,7 @@ import org.jblas.DoubleMatrix
abstract class Updater extends Serializable {
/**
* Compute an updated value for weights given the gradient, stepSize and iteration number.
+ * Also returns the regularization value computed using the *updated* weights.
*
* @param weightsOlds - Column matrix of size nx1 where n is the number of features.
* @param gradient - Column matrix of size nx1 where n is the number of features.
@@ -31,7 +32,7 @@ abstract class Updater extends Serializable {
* @param regParam - Regularization parameter
*
* @return A tuple of 2 elements. The first element is a column matrix containing updated weights,
- * and the second element is the regularization value.
+ * and the second element is the regularization value computed using updated weights.
*/
def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int, regParam: Double):
(DoubleMatrix, Double)
@@ -46,13 +47,13 @@ class SimpleUpdater extends Updater {
}
/**
-* L1 regularization -- corresponding proximal operator is the soft-thresholding function
-* That is, each weight component is shrunk towards 0 by shrinkageVal
-* If w > shrinkageVal, set weight component to w-shrinkageVal.
-* If w < -shrinkageVal, set weight component to w+shrinkageVal.
-* If -shrinkageVal < w < shrinkageVal, set weight component to 0.
-* Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal)
-**/
+ * L1 regularization -- corresponding proximal operator is the soft-thresholding function
+ * That is, each weight component is shrunk towards 0 by shrinkageVal
+ * If w > shrinkageVal, set weight component to w-shrinkageVal.
+ * If w < -shrinkageVal, set weight component to w+shrinkageVal.
+ * If -shrinkageVal < w < shrinkageVal, set weight component to 0.
+ * Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal)
+ */
class L1Updater extends Updater {
override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {