diff options
author | Davies Liu <davies@databricks.com> | 2014-11-18 10:11:13 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-18 10:11:13 -0800 |
commit | 8fbf72b7903b5bbec8d949151aa4693b4af26ff5 (patch) | |
tree | 47878ffe9262b30cc626c173c341caadb88390ce /mllib | |
parent | cedc3b5aa43a16e2da62f12a36317f00aa1002cc (diff) | |
download | spark-8fbf72b7903b5bbec8d949151aa4693b4af26ff5.tar.gz spark-8fbf72b7903b5bbec8d949151aa4693b4af26ff5.tar.bz2 spark-8fbf72b7903b5bbec8d949151aa4693b4af26ff5.zip |
[SPARK-4435] [MLlib] [PySpark] improve classification
This PR add setThrehold() and clearThreshold() for LogisticRegressionModel and SVMModel, also support RDD of vector in LogisticRegressionModel.predict(), SVNModel.predict() and NaiveBayes.predict()
Author: Davies Liu <davies@databricks.com>
Closes #3305 from davies/setThreshold and squashes the following commits:
d0b835f [Davies Liu] Merge branch 'master' of github.com:apache/spark into setThreshold
e4acd76 [Davies Liu] address comments
2231a5f [Davies Liu] bugfix
7bd9009 [Davies Liu] address comments
0b0a8a7 [Davies Liu] address comments
c1e5573 [Davies Liu] improve classification
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala | 2 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala | 2 |
2 files changed, 2 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 18b95f1edc..94d757bc31 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -64,7 +64,7 @@ class LogisticRegressionModel ( val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { - case Some(t) => if (score < t) 0.0 else 1.0 + case Some(t) => if (score > t) 1.0 else 0.0 case None => score } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index ab9515b2a6..dd514ff8a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -65,7 +65,7 @@ class SVMModel ( intercept: Double) = { val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept threshold match { - case Some(t) => if (margin < t) 0.0 else 1.0 + case Some(t) => if (margin > t) 1.0 else 0.0 case None => margin } } |