about summary refs log tree commit diff
path: root/mllib/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala | 12
1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index 5a419d1640..aaacf3a8a2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -64,11 +64,17 @@ class LogisticGradient extends Gradient {
val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
val gradient = data.copy
scal(gradientMultiplier, gradient)
+ val minusYP = if (label > 0) margin else -margin
+
+ // log1p is log(1+p) but more accurate for small p
+ // Following two equations are the same analytically but not numerically, e.g.,
+ // math.log1p(math.exp(1000)) == Infinity
+ // 1000 + math.log1p(math.exp(-1000)) == 1000.0
val loss =
- if (label > 0) {
- math.log1p(math.exp(margin)) // log1p is log(1+p) but more accurate for small p
+ if (minusYP < 0) {
+ math.log1p(math.exp(minusYP))
} else {
- math.log1p(math.exp(margin)) - margin
+ math.log1p(math.exp(-minusYP)) + minusYP
}
(gradient, loss)