aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorXinghao <pxinghao@gmail.com>2013-07-29 09:19:56 -0700
committerXinghao <pxinghao@gmail.com>2013-07-29 09:19:56 -0700
commit75f375730025788a5982146d97bf3df9ef69ab23 (patch)
treef04a0c3a3c755ba5df7b76d9ffa509c6f0916f3f /mllib/src
parentc823ee1e2bea7cde61cb4411a0f0db91f1df2af2 (diff)
downloadspark-75f375730025788a5982146d97bf3df9ef69ab23.tar.gz
spark-75f375730025788a5982146d97bf3df9ef69ab23.tar.bz2
spark-75f375730025788a5982146d97bf3df9ef69ab23.zip
Fix rounding error in LogisticRegression.scala
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala6
1 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
index cbc0d03ae1..bc1c327729 100644
--- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
@@ -21,6 +21,8 @@ import spark.{Logging, RDD, SparkContext}
import spark.mllib.optimization._
import spark.mllib.util.MLUtils
+import scala.math.round
+
import org.jblas.DoubleMatrix
/**
@@ -42,14 +44,14 @@ class LogisticRegressionModel(
val localIntercept = intercept
testData.map { x =>
val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept
- (1.0/ (1.0 + math.exp(margin * -1))).toInt
+ round(1.0/ (1.0 + math.exp(margin * -1))).toInt
}
}
override def predict(testData: Array[Double]): Int = {
val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
val margin = dataMat.mmul(weightsMatrix).get(0) + this.intercept
- (1.0/ (1.0 + math.exp(margin * -1))).toInt
+ round(1.0/ (1.0 + math.exp(margin * -1))).toInt
}
}