aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2014-11-18 10:11:13 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-18 10:11:22 -0800
commita28902f25fc2a685c4a5663e976c1d735265ecb0 (patch)
treea22be1932e4a58a1b1a3ed1fae759f10c4b9f012 /mllib/src
parent4f0477d6f94c85c4777a2f5d587faa539780cded (diff)
downloadspark-a28902f25fc2a685c4a5663e976c1d735265ecb0.tar.gz
spark-a28902f25fc2a685c4a5663e976c1d735265ecb0.tar.bz2
spark-a28902f25fc2a685c4a5663e976c1d735265ecb0.zip
[SPARK-4435] [MLlib] [PySpark] improve classification
This PR add setThrehold() and clearThreshold() for LogisticRegressionModel and SVMModel, also support RDD of vector in LogisticRegressionModel.predict(), SVNModel.predict() and NaiveBayes.predict() Author: Davies Liu <davies@databricks.com> Closes #3305 from davies/setThreshold and squashes the following commits: d0b835f [Davies Liu] Merge branch 'master' of github.com:apache/spark into setThreshold e4acd76 [Davies Liu] address comments 2231a5f [Davies Liu] bugfix 7bd9009 [Davies Liu] address comments 0b0a8a7 [Davies Liu] address comments c1e5573 [Davies Liu] improve classification (cherry picked from commit 8fbf72b7903b5bbec8d949151aa4693b4af26ff5) Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala2
2 files changed, 2 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 18b95f1edc..94d757bc31 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -64,7 +64,7 @@ class LogisticRegressionModel (
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
val score = 1.0 / (1.0 + math.exp(-margin))
threshold match {
- case Some(t) => if (score < t) 0.0 else 1.0
+ case Some(t) => if (score > t) 1.0 else 0.0
case None => score
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index ab9515b2a6..dd514ff8a3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -65,7 +65,7 @@ class SVMModel (
intercept: Double) = {
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
threshold match {
- case Some(t) => if (margin < t) 0.0 else 1.0
+ case Some(t) => if (margin > t) 1.0 else 0.0
case None => margin
}
}