diff options
author | Xinghao <pxinghao@gmail.com> | 2013-07-28 22:12:39 -0700 |
---|---|---|
committer | Xinghao <pxinghao@gmail.com> | 2013-07-28 22:12:39 -0700 |
commit | 96e04f4cb7de3a7c9d31aa7acba496d81066634e (patch) | |
tree | a81b6c706f31681ab3013bdac1f2403a48b7312d /mllib/src/test | |
parent | 9398dced0331c0ec098ef5eb4616571874ceefb6 (diff) | |
download | spark-96e04f4cb7de3a7c9d31aa7acba496d81066634e.tar.gz spark-96e04f4cb7de3a7c9d31aa7acba496d81066634e.tar.bz2 spark-96e04f4cb7de3a7c9d31aa7acba496d81066634e.zip |
Fixed SVM and LR train functions to take Int instead of Double for Classification
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala index 144b8b1bc7..3aa9fe6d12 100644 --- a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala @@ -38,7 +38,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { offset: Double, scale: Double, nPoints: Int, - seed: Int): Seq[(Double, Array[Double])] = { + seed: Int): Seq[(Int, Array[Double])] = { val rnd = new Random(seed) val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) @@ -51,19 +51,19 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { // y <- A + B*x + rLogis() // y <- as.numeric(y > 0) - val y: Seq[Double] = (0 until nPoints).map { i => + val y: Seq[Int] = (0 until nPoints).map { i => val yVal = offset + scale * x1(i) + rLogis(i) - if (yVal > 0) 1.0 else 0.0 + if (yVal > 0) 1 else 0 } val testData = (0 until nPoints).map(i => (y(i), Array(x1(i)))) testData } - def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) { + def validatePrediction(predictions: Seq[Int], input: Seq[(Int, Array[Double])]) { val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) => // A prediction is off if the prediction is more than 0.5 away from expected value. - math.abs(prediction - expected) > 0.5 + math.abs(prediction.toDouble - expected.toDouble) > 0.5 }.size // At least 80% of the predictions should be on. assert(numOffPredictions < input.length / 5) |