diff options
author | Xinghao <pxinghao@gmail.com> | 2013-07-29 09:19:56 -0700 |
---|---|---|
committer | Xinghao <pxinghao@gmail.com> | 2013-07-29 09:19:56 -0700 |
commit | 75f375730025788a5982146d97bf3df9ef69ab23 (patch) | |
tree | f04a0c3a3c755ba5df7b76d9ffa509c6f0916f3f /mllib/src/main | |
parent | c823ee1e2bea7cde61cb4411a0f0db91f1df2af2 (diff) | |
download | spark-75f375730025788a5982146d97bf3df9ef69ab23.tar.gz spark-75f375730025788a5982146d97bf3df9ef69ab23.tar.bz2 spark-75f375730025788a5982146d97bf3df9ef69ab23.zip |
Fix rounding error in LogisticRegression.scala
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala index cbc0d03ae1..bc1c327729 100644 --- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala @@ -21,6 +21,8 @@ import spark.{Logging, RDD, SparkContext} import spark.mllib.optimization._ import spark.mllib.util.MLUtils +import scala.math.round + import org.jblas.DoubleMatrix /** @@ -42,14 +44,14 @@ class LogisticRegressionModel( val localIntercept = intercept testData.map { x => val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept - (1.0/ (1.0 + math.exp(margin * -1))).toInt + round(1.0/ (1.0 + math.exp(margin * -1))).toInt } } override def predict(testData: Array[Double]): Int = { val dataMat = new DoubleMatrix(1, testData.length, testData:_*) val margin = dataMat.mmul(weightsMatrix).get(0) + this.intercept - (1.0/ (1.0 + math.exp(margin * -1))).toInt + round(1.0/ (1.0 + math.exp(margin * -1))).toInt } } |