author	Reynold Xin <reynoldx@gmail.com>	2013-07-23 12:52:15 -0700
committer	Reynold Xin <reynoldx@gmail.com>	2013-07-23 12:52:15 -0700
commit	2210e8ccf8d77f65442a344c4eae39e000fba927 (patch)
tree	ae961efd8905c9618352a40580941241d3b28217 /mllib
parent	87a9dd898ff51fd110799edae087d59f6b714211 (diff)
download	spark-2210e8ccf8d77f65442a344c4eae39e000fba927.tar.gz
spark-2210e8ccf8d77f65442a344c4eae39e000fba927.tar.bz2
spark-2210e8ccf8d77f65442a344c4eae39e000fba927.zip
Use a different validation dataset for Logistic Regression prediction testing.
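The pattern this commit moves to, shown as a small standalone sketch: the data used to check predictions is generated with a different seed (17) than the training data (42), so the prediction checks run on points the model was never fit on. The helper below mirrors generateLogisticInput from the diff; the wrapper object and the label-generation step are not shown in the hunk and are reconstructed here for illustration only.

import scala.util.Random

object SeededLogisticData {
  // Mirrors the suite's generateLogisticInput: Y = logistic(offset + scale * x),
  // with an explicit seed so callers can produce independent datasets.
  def generateLogisticInput(
      offset: Double,
      scale: Double,
      nPoints: Int,
      seed: Int): Seq[(Double, Array[Double])] = {
    val rnd = new Random(seed)
    val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
    // NOTE: if U is uniform[0, 1] then ln(u) - ln(1-u) is Logistic(0,1).
    // The thresholding below is a reconstruction for illustration; the hunk
    // does not show the suite's actual label computation.
    (0 until nPoints).map { i =>
      val u = rnd.nextDouble()
      val logisticNoise = math.log(u) - math.log(1.0 - u)
      val label = if (offset + scale * x1(i) + logisticNoise > 0) 1.0 else 0.0
      (label, Array(x1(i)))
    }
  }

  def main(args: Array[String]): Unit = {
    val nPoints = 10000
    val A = 2.0   // intercept the test expects the model to recover
    val B = -1.5  // weight the test expects the model to recover
    val trainData      = generateLogisticInput(A, B, nPoints, seed = 42)
    val validationData = generateLogisticInput(A, B, nPoints, seed = 17)
    // Different seeds: same distribution, different points, so prediction
    // checks no longer reuse the exact points the model was fit on.
    println(s"first train feature:      ${trainData.head._2(0)}")
    println(s"first validation feature: ${validationData.head._2(0)}")
  }
}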
Diffstat (limited to 'mllib')
-rw-r--r--	mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala	29
1 file changed, 17 insertions, 12 deletions
diff --git a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
index 6a8098b59d..0a99b78cf8 100644
--- a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
@@ -35,10 +35,11 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
// Generate input of the form Y = logistic(offset + scale*X)
def generateLogisticInput(
- offset: Double,
- scale: Double,
- nPoints: Int) : Seq[(Double, Array[Double])] = {
- val rnd = new Random(42)
+ offset: Double,
+ scale: Double,
+ nPoints: Int,
+ seed: Int): Seq[(Double, Array[Double])] = {
+ val rnd = new Random(seed)
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
// NOTE: if U is uniform[0, 1] then ln(u) - ln(1-u) is Logistic(0,1)
@@ -60,12 +61,12 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
}
def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) {
- val offPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
+ val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
// A prediction is off if the prediction is more than 0.5 away from expected value.
math.abs(prediction - expected) > 0.5
}.size
// At least 80% of the predictions should be on.
- assert(offPredictions < input.length / 5)
+ assert(numOffPredictions < input.length / 5)
}
// Test if we can correctly learn A, B where Y = logistic(A + B*X)
@@ -74,7 +75,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
val A = 2.0
val B = -1.5
- val testData = generateLogisticInput(A, B, nPoints)
+ val testData = generateLogisticInput(A, B, nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
@@ -87,11 +88,13 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
+ val validationData = generateLogisticInput(A, B, nPoints, 17)
+ val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
- validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData)
+ validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
// Test prediction on Array.
- validatePrediction(testData.map(row => model.predict(row._2)), testData)
+ validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
}
test("logistic regression with initial weights") {
@@ -99,7 +102,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
val A = 2.0
val B = -1.5
- val testData = generateLogisticInput(A, B, nPoints)
+ val testData = generateLogisticInput(A, B, nPoints, 42)
val initialB = -1.0
val initialWeights = Array(initialB)
@@ -116,10 +119,12 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
+ val validationData = generateLogisticInput(A, B, nPoints, 17)
+ val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
- validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData)
+ validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
// Test prediction on Array.
- validatePrediction(testData.map(row => model.predict(row._2)), testData)
+ validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
}
}
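A short note on the inverse-CDF fact cited in generateLogisticInput's NOTE comment (standard probability, not part of the commit): the Logistic(0,1) CDF is F(x) = 1 / (1 + e^(-x)); solving u = F(x) for x gives F^(-1)(u) = ln(u / (1 - u)) = ln(u) - ln(1 - u). So if U ~ Uniform(0, 1), then ln(U) - ln(1 - U) has the Logistic(0,1) distribution. Adding this noise to offset + scale*x and thresholding at 0 yields labels with P(Y = 1 | x) = 1 / (1 + e^-(offset + scale*x)), i.e. Y = logistic(offset + scale*X), as the comment at the top of the generator states.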