aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/spark/ml/optimization/Gradient.scala
blob: 90b0999a5ec40c7af2ddba908de89f1ec7ee4b89 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
package spark.mllib.optimization

import org.jblas.DoubleMatrix

/**
 * Per-example gradient computation for an objective function.
 *
 * Implementations must be `Serializable` (enforced by this base class).
 */
abstract class Gradient extends Serializable {
  /**
   * Computes the gradient and loss contributed by a single example.
   *
   * @param data    the example's features: a 1 x n row matrix, n = number of features
   * @param label   the example's label
   * @param weights the current weights: a column matrix with one entry per feature
   * @return a pair `(gradient, loss)` for this example
   */
  def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): (DoubleMatrix, Double)
}

/**
 * Gradient and loss for binary logistic regression.
 *
 * With z = data · weights and p = 1 / (1 + exp(-z)), the per-example loss is the
 * cross-entropy  -label * log(p) - (1 - label) * log(1 - p), whose gradient with
 * respect to the weights is (p - label) * data. Labels are expected to be 0 or 1;
 * soft labels in [0, 1] are handled consistently by both the loss and the gradient.
 */
class LogisticGradient extends Gradient {
  override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
      (DoubleMatrix, Double) = {
    // margin = -z, so that p = 1 / (1 + exp(margin)).
    val margin: Double = -1.0 * data.dot(weights)
    val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label

    val gradient = data.mul(gradientMultiplier)

    // -log(p)     = log(1 + exp(-z)) = log1pExp(margin)   (loss when label = 1)
    // -log(1 - p) = log(1 + exp(z))  = log1pExp(-margin)  (loss when label = 0)
    // The loss must depend on the label; weighting both terms keeps it the exact
    // antiderivative of the gradient above.
    val loss = label * log1pExp(margin) + (1.0 - label) * log1pExp(-margin)

    (gradient, loss)
  }

  /**
   * Numerically stable log(1 + exp(x)).
   *
   * Only ever evaluates `exp` of a non-positive argument, so it cannot overflow
   * for large |x|; for large positive x the result is ~x rather than Infinity.
   */
  private def log1pExp(x: Double): Double =
    if (x > 0) x + math.log1p(math.exp(-x)) else math.log1p(math.exp(x))
}