aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorXinghao <pxinghao@gmail.com>2013-07-28 22:12:39 -0700
committerXinghao <pxinghao@gmail.com>2013-07-28 22:12:39 -0700
commit96e04f4cb7de3a7c9d31aa7acba496d81066634e (patch)
treea81b6c706f31681ab3013bdac1f2403a48b7312d /mllib/src/test
parent9398dced0331c0ec098ef5eb4616571874ceefb6 (diff)
downloadspark-96e04f4cb7de3a7c9d31aa7acba496d81066634e.tar.gz
spark-96e04f4cb7de3a7c9d31aa7acba496d81066634e.tar.bz2
spark-96e04f4cb7de3a7c9d31aa7acba496d81066634e.zip
Fixed SVM and LR train functions to take Int instead of Double for Classification
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala10
1 files changed, 5 insertions, 5 deletions
diff --git a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
index 144b8b1bc7..3aa9fe6d12 100644
--- a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -38,7 +38,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
offset: Double,
scale: Double,
nPoints: Int,
- seed: Int): Seq[(Double, Array[Double])] = {
+ seed: Int): Seq[(Int, Array[Double])] = {
val rnd = new Random(seed)
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
@@ -51,19 +51,19 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
// y <- A + B*x + rLogis()
// y <- as.numeric(y > 0)
- val y: Seq[Double] = (0 until nPoints).map { i =>
+ val y: Seq[Int] = (0 until nPoints).map { i =>
val yVal = offset + scale * x1(i) + rLogis(i)
- if (yVal > 0) 1.0 else 0.0
+ if (yVal > 0) 1 else 0
}
val testData = (0 until nPoints).map(i => (y(i), Array(x1(i))))
testData
}
- def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) {
+ def validatePrediction(predictions: Seq[Int], input: Seq[(Int, Array[Double])]) {
val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
// A prediction is off if the prediction is more than 0.5 away from expected value.
- math.abs(prediction - expected) > 0.5
+ math.abs(prediction.toDouble - expected.toDouble) > 0.5
}.size
// At least 80% of the predictions should be on.
assert(numOffPredictions < input.length / 5)