diff options
author | Reynold Xin <reynoldx@gmail.com> | 2013-07-23 12:52:15 -0700 |
---|---|---|
committer | Reynold Xin <reynoldx@gmail.com> | 2013-07-23 12:52:15 -0700 |
commit | 2210e8ccf8d77f65442a344c4eae39e000fba927 (patch) | |
tree | ae961efd8905c9618352a40580941241d3b28217 /mllib | |
parent | 87a9dd898ff51fd110799edae087d59f6b714211 (diff) | |
download | spark-2210e8ccf8d77f65442a344c4eae39e000fba927.tar.gz spark-2210e8ccf8d77f65442a344c4eae39e000fba927.tar.bz2 spark-2210e8ccf8d77f65442a344c4eae39e000fba927.zip |
Use a different validation dataset for Logistic Regression prediction testing.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala | 29 |
1 files changed, 17 insertions, 12 deletions
diff --git a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala index 6a8098b59d..0a99b78cf8 100644 --- a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala @@ -35,10 +35,11 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { // Generate input of the form Y = logistic(offset + scale*X) def generateLogisticInput( - offset: Double, - scale: Double, - nPoints: Int) : Seq[(Double, Array[Double])] = { - val rnd = new Random(42) + offset: Double, + scale: Double, + nPoints: Int, + seed: Int): Seq[(Double, Array[Double])] = { + val rnd = new Random(seed) val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) // NOTE: if U is uniform[0, 1] then ln(u) - ln(1-u) is Logistic(0,1) @@ -60,12 +61,12 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { } def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) { - val offPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) => + val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) => // A prediction is off if the prediction is more than 0.5 away from expected value. math.abs(prediction - expected) > 0.5 }.size // At least 80% of the predictions should be on. - assert(offPredictions < input.length / 5) + assert(numOffPredictions < input.length / 5) } // Test if we can correctly learn A, B where Y = logistic(A + B*X) @@ -74,7 +75,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { val A = 2.0 val B = -1.5 - val testData = generateLogisticInput(A, B, nPoints) + val testData = generateLogisticInput(A, B, nPoints, 42) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -87,11 +88,13 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + val validationData = generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData) + validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData) // Test prediction on Array. - validatePrediction(testData.map(row => model.predict(row._2)), testData) + validatePrediction(validationData.map(row => model.predict(row._2)), validationData) } test("logistic regression with initial weights") { @@ -99,7 +102,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { val A = 2.0 val B = -1.5 - val testData = generateLogisticInput(A, B, nPoints) + val testData = generateLogisticInput(A, B, nPoints, 42) val initialB = -1.0 val initialWeights = Array(initialB) @@ -116,10 +119,12 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + val validationData = generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData) + validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData) // Test prediction on Array. - validatePrediction(testData.map(row => model.predict(row._2)), testData) + validatePrediction(validationData.map(row => model.predict(row._2)), validationData) } } |