aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorDB Tsai <dbtsai@alpinenow.com>2015-01-07 10:13:41 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-07 10:13:41 -0800
commit60e2d9e2902b132b14191c9791c71e8f0d42ce9d (patch)
tree3b743dec60461a9c49057cb3ccdfd0c780771bef /mllib
parent6e74edeca31acd7dc84a34402e430e017591d858 (diff)
downloadspark-60e2d9e2902b132b14191c9791c71e8f0d42ce9d.tar.gz
spark-60e2d9e2902b132b14191c9791c71e8f0d42ce9d.tar.bz2
spark-60e2d9e2902b132b14191c9791c71e8f0d42ce9d.zip
[SPARK-5128][MLLib] Add common used log1pExp API in MLUtils
When `x` is positive and large, computing `math.log(1 + math.exp(x))` will lead to arithmetic overflow. This will happen when `x > 709.78` which is not a very large number. It can be addressed by rewriting the formula into `x + math.log1p(math.exp(-x))` when `x > 0`.

Author: DB Tsai <dbtsai@alpinenow.com>

Closes #3915 from dbtsai/mathutil and squashes the following commits:

bec6a84 [DB Tsai] remove empty line
3239541 [DB Tsai] revert part of patch into another PR
23144f3 [DB Tsai] doc
49f3658 [DB Tsai] temp
6c29ed3 [DB Tsai] formating
f8447f9 [DB Tsai] address another overflow issue in gradientMultiplier in LOR gradient code
64eefd0 [DB Tsai] first commit
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala19
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala10
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala16
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala13
4 files changed, 37 insertions, 21 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index aaacf3a8a2..1ca0f36c6a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.optimization
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
+import org.apache.spark.mllib.util.MLUtils
/**
* :: DeveloperApi ::
@@ -64,17 +65,12 @@ class LogisticGradient extends Gradient {
val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
val gradient = data.copy
scal(gradientMultiplier, gradient)
- val minusYP = if (label > 0) margin else -margin
-
- // log1p is log(1+p) but more accurate for small p
- // Following two equations are the same analytically but not numerically, e.g.,
- // math.log1p(math.exp(1000)) == Infinity
- // 1000 + math.log1p(math.exp(-1000)) == 1000.0
val loss =
- if (minusYP < 0) {
- math.log1p(math.exp(minusYP))
+ if (label > 0) {
+ // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
+ MLUtils.log1pExp(margin)
} else {
- math.log1p(math.exp(-minusYP)) + minusYP
+ MLUtils.log1pExp(margin) - margin
}
(gradient, loss)
@@ -89,9 +85,10 @@ class LogisticGradient extends Gradient {
val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
axpy(gradientMultiplier, data, cumGradient)
if (label > 0) {
- math.log1p(math.exp(margin))
+ // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
+ MLUtils.log1pExp(margin)
} else {
- math.log1p(math.exp(margin)) - margin
+ MLUtils.log1pExp(margin) - margin
}
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 7ce9fa6f86..55213e6956 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
+import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
/**
@@ -61,13 +62,8 @@ object LogLoss extends Loss {
data.map { case point =>
val prediction = model.predict(point.features)
val margin = 2.0 * point.label * prediction
- // The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically
- // stable.
- if (margin >= 0) {
- 2.0 * math.log1p(math.exp(-margin))
- } else {
- 2.0 * (-margin + math.log1p(math.exp(margin)))
- }
+ // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
+ 2.0 * MLUtils.log1pExp(-margin)
}.mean()
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index c7843464a7..5d6ddd47f6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -322,4 +322,20 @@ object MLUtils {
}
sqDist
}
+
+ /**
+ * When `x` is positive and large, computing `math.log(1 + math.exp(x))` will lead to arithmetic
+ * overflow. This will happen when `x > 709.78` which is not a very large number.
+ * It can be addressed by rewriting the formula into `x + math.log1p(math.exp(-x))` when `x > 0`.
+ *
+ * @param x a floating-point value as input.
+ * @return the result of `math.log(1 + math.exp(x))`.
+ */
+ private[mllib] def log1pExp(x: Double): Double = {
+ if (x > 0) {
+ x + math.log1p(math.exp(-x))
+ } else {
+ math.log1p(math.exp(x))
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 7778847f8b..668fc1d43c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -20,18 +20,17 @@ package org.apache.spark.mllib.util
import java.io.File
import scala.io.Source
-import scala.math
import org.scalatest.FunSuite
-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
- squaredDistance => breezeSquaredDistance}
+import breeze.linalg.{squaredDistance => breezeSquaredDistance}
import com.google.common.base.Charsets
import com.google.common.io.Files
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
@@ -204,4 +203,12 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
assert(points.collect().toSet === loaded.collect().toSet)
Utils.deleteRecursively(tempDir)
}
+
+ test("log1pExp") {
+ assert(log1pExp(76.3) ~== math.log1p(math.exp(76.3)) relTol 1E-10)
+ assert(log1pExp(87296763.234) ~== 87296763.234 relTol 1E-10)
+
+ assert(log1pExp(-13.8) ~== math.log1p(math.exp(-13.8)) absTol 1E-10)
+ assert(log1pExp(-238423789.865) ~== math.log1p(math.exp(-238423789.865)) absTol 1E-10)
+ }
}