aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXinghao <pxinghao@gmail.com>2013-07-29 09:22:31 -0700
committerXinghao <pxinghao@gmail.com>2013-07-29 09:22:31 -0700
commit07f17439a52b65d4f5ef8c8d80bc25dadc0182a8 (patch)
tree03b38fc234cc3a5dc5990d9c3b23a0d0bcfe7729 /mllib
parent3a8d07df8ca5bccdbed178991dd12fde74802542 (diff)
downloadspark-07f17439a52b65d4f5ef8c8d80bc25dadc0182a8.tar.gz
spark-07f17439a52b65d4f5ef8c8d80bc25dadc0182a8.tar.bz2
spark-07f17439a52b65d4f5ef8c8d80bc25dadc0182a8.zip
Fix validatePrediction functions for Classification models
Classifiers return categorical (Int) values that should be compared directly
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala3
-rw-r--r--mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala3
2 files changed, 2 insertions, 4 deletions
diff --git a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
index 3aa9fe6d12..d3fe58a382 100644
--- a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -62,8 +62,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
def validatePrediction(predictions: Seq[Int], input: Seq[(Int, Array[Double])]) {
val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
- // A prediction is off if the prediction is more than 0.5 away from expected value.
- math.abs(prediction.toDouble - expected.toDouble) > 0.5
+ (prediction != expected)
}.size
// At least 80% of the predictions should be on.
assert(numOffPredictions < input.length / 5)
diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
index 3f00398a0a..d546e0729e 100644
--- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
@@ -52,8 +52,7 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
def validatePrediction(predictions: Seq[Int], input: Seq[(Int, Array[Double])]) {
val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
- // A prediction is off if the prediction is more than 0.5 away from expected value.
- math.abs(prediction - expected) > 0.5
+ (prediction != expected)
}.size
// At least 80% of the predictions should be on.
assert(numOffPredictions < input.length / 5)