diff options
author | Xinghao <pxinghao@gmail.com> | 2013-07-29 09:22:31 -0700 |
---|---|---|
committer | Xinghao <pxinghao@gmail.com> | 2013-07-29 09:22:31 -0700 |
commit | 07f17439a52b65d4f5ef8c8d80bc25dadc0182a8 (patch) | |
tree | 03b38fc234cc3a5dc5990d9c3b23a0d0bcfe7729 /mllib | |
parent | 3a8d07df8ca5bccdbed178991dd12fde74802542 (diff) | |
download | spark-07f17439a52b65d4f5ef8c8d80bc25dadc0182a8.tar.gz spark-07f17439a52b65d4f5ef8c8d80bc25dadc0182a8.tar.bz2 spark-07f17439a52b65d4f5ef8c8d80bc25dadc0182a8.zip |
Fix validatePrediction functions for Classification models
Classifiers return categorical (Int) values that should be compared
directly
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala | 3 | ||||
-rw-r--r-- | mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala | 3 |
2 files changed, 2 insertions, 4 deletions
diff --git a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala index 3aa9fe6d12..d3fe58a382 100644 --- a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala @@ -62,8 +62,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { def validatePrediction(predictions: Seq[Int], input: Seq[(Int, Array[Double])]) { val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) => - // A prediction is off if the prediction is more than 0.5 away from expected value. - math.abs(prediction.toDouble - expected.toDouble) > 0.5 + (prediction != expected) }.size // At least 80% of the predictions should be on. assert(numOffPredictions < input.length / 5) diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala index 3f00398a0a..d546e0729e 100644 --- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala @@ -52,8 +52,7 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll { def validatePrediction(predictions: Seq[Int], input: Seq[(Int, Array[Double])]) { val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) => - // A prediction is off if the prediction is more than 0.5 away from expected value. - math.abs(prediction - expected) > 0.5 + (prediction != expected) }.size // At least 80% of the predictions should be on. assert(numOffPredictions < input.length / 5) |