author: Xiangrui Meng <meng@databricks.com> 2014-07-20 18:40:36 -0700
committer: Xiangrui Meng <meng@databricks.com> 2014-07-20 18:40:36 -0700
commit: b86db517b6a2795f687211205b6a14c8685873eb
tree: e9754a47e8c2c21581599a1b67018fcd59cc7042 /python
parent: 9564f8548917f563930d5e87911a304bf206d26e
[SPARK-2552][MLLIB] stabilize logistic function in pyspark
to avoid overflow in `exp(x)` if `x` is large.

Author: Xiangrui Meng <meng@databricks.com>

Closes #1493 from mengxr/py-logistic and squashes the following commits:

259e863 [Xiangrui Meng] stabilize logistic function in pyspark

Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/classification.py | 5
1 files changed, 4 insertions, 1 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 1c0c536c4f..9e28dfbb91 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -63,7 +63,10 @@ class LogisticRegressionModel(LinearModel):
     def predict(self, x):
         _linear_predictor_typecheck(x, self._coeff)
         margin = _dot(x, self._coeff) + self._intercept
-        prob = 1/(1 + exp(-margin))
+        if margin > 0:
+            prob = 1 / (1 + exp(-margin))
+        else:
+            prob = 1 - 1 / (1 + exp(margin))
         return 1 if prob > 0.5 else 0
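
Note on the change: a minimal standalone sketch of why the branch on the sign of `margin` helps. It assumes `exp` behaves like the standard library's `math.exp`, which raises `OverflowError` for large arguments. The names `naive_logistic`, `stable_logistic`, and `predict_label` are hypothetical and exist only for this illustration; they are not part of pyspark.

```python
import math


def naive_logistic(margin):
    # Pre-patch form: exp(-margin) overflows when margin is very negative.
    return 1 / (1 + math.exp(-margin))


def stable_logistic(margin):
    # Post-patch form: exp() only ever receives a non-positive argument,
    # so it can underflow to 0.0 but cannot overflow.
    if margin > 0:
        return 1 / (1 + math.exp(-margin))
    else:
        return 1 - 1 / (1 + math.exp(margin))


def predict_label(margin):
    # Same 0.5 threshold as in the patched predict().
    return 1 if stable_logistic(margin) > 0.5 else 0


if __name__ == "__main__":
    try:
        naive_logistic(-1000.0)
    except OverflowError as err:
        print("naive form failed:", err)
    print(stable_logistic(-1000.0), stable_logistic(1000.0))
```

Both branches compute the same value mathematically, since 1 - 1/(1 + e^m) = 1/(1 + e^(-m)); the branch only selects the form whose `exp` argument is non-positive.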