author | DB Tsai <dbtsai@alpinenow.com> | 2014-12-18 13:55:49 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-12-18 13:55:49 -0800 |
commit | 59a49db5982ecc487187fcd92399e08b4b4bea64 (patch) | |
tree | d51225b6bc70c0cee3e1dd822c843abbb11757b6 /mllib | |
parent | 3720057b8e7c15c2c0464b5bb7243bc22323f4e8 (diff) | |
[SPARK-4887][MLlib] Fix a bad unittest in LogisticRegressionSuite
The original test doesn't make sense: if you step into it, the lossSum is already NaN
and the coefficients are diverging. The step size is too large for SGD, so the
optimizer does not converge.
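The divergence described above can be reproduced on a toy problem. The following is a minimal sketch, not the MLlib optimizer: it runs plain fixed-step gradient descent on an L2-regularized logistic loss with one feature and no intercept. The object `StepSizeSketch`, its data, and the `fit` helper are hypothetical names introduced only for illustration.

```scala
object StepSizeSketch {
  // Hypothetical toy data: one feature, binary labels, not linearly separable.
  val xs: Array[Double] = Array(-2.0, -1.0, 1.0, 2.0, 1.0, -1.0)
  val ys: Array[Double] = Array(0.0, 0.0, 1.0, 1.0, 0.0, 1.0)

  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Gradient of the mean logistic loss plus (regParam / 2) * w^2.
  def gradient(w: Double, regParam: Double): Double = {
    val dataGrad =
      xs.zip(ys).map { case (x, y) => (sigmoid(w * x) - y) * x }.sum / xs.length
    dataGrad + regParam * w
  }

  // Fixed-step gradient descent starting from w = 0 (an assumption of this
  // sketch; it is not the step-size schedule of any particular library).
  def fit(stepSize: Double, regParam: Double, numIterations: Int): Double =
    (1 to numIterations).foldLeft(0.0) { (w, _) =>
      w - stepSize * gradient(w, regParam)
    }

  def main(args: Array[String]): Unit = {
    // Step size 10.0: each update overshoots and the weight blows up.
    println(s"stepSize=10.0, regParam=1.0: ${fit(10.0, 1.0, 10)}")
    // Step size 1.0: the weight stays small.
    println(s"stepSize=1.0,  regParam=1.0: ${fit(1.0, 1.0, 10)}")
    // For comparison, the same step size without regularization.
    println(s"stepSize=1.0,  regParam=0.0: ${fit(1.0, 0.0, 10)}")
  }
}
```

With a step size of 10.0 and a regularization parameter of 1.0, each update in this sketch multiplies the weight by roughly -9, so the magnitude grows geometrically. With a step size of 1.0 the iterates stay bounded, and the regularized weight ends up smaller in magnitude than the unregularized one.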
The correct behavior is that you get smaller coefficients than the ones obtained
without regularization. Comparing the values with a relative error of 20000.0
doesn't make sense either.
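To see why a relative tolerance of 20000.0 checks almost nothing, consider the sketch below. It assumes a relative-error comparison of the form |x - y| < eps * min(|x|, |y|); the helper `approxEqualRel` is a hypothetical stand-in written for this illustration and may differ in detail from the `~==` / `relTol` helper used in the test suite.

```scala
object RelTolSketch {
  // Assumed relative-error check: |x - y| < eps * min(|x|, |y|).
  // This is an illustrative stand-in, not the test suite's own helper.
  def approxEqualRel(x: Double, y: Double, eps: Double): Boolean = {
    val diff = math.abs(x - y)
    x == y || diff < eps * math.min(math.abs(x), math.abs(y))
  }

  def main(args: Array[String]): Unit = {
    // With eps = 20000.0, values thousands of times apart still "match",
    // even when their signs differ.
    println(approxEqualRel(-100.0, -430000.0, 20000.0))
    println(approxEqualRel(100.0, -430000.0, 20000.0))
    // With eps = 0.02, only values within about 2% of each other match.
    println(approxEqualRel(-0.141, -0.14, 0.02))
    println(approxEqualRel(-0.2, -0.14, 0.02))
  }
}
```

Under that assumed definition, the old assertions would accept a wide range of wrong values, while a tolerance of 0.02 pins the expected weight and intercept to within about two percent.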
Author: DB Tsai <dbtsai@alpinenow.com>
Closes #3735 from dbtsai/mlortestfix and squashes the following commits:
b1a3c42 [DB Tsai] first commit
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 4e81299440..94b0e00f37 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -178,15 +178,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
     // Use half as many iterations as the previous test.
     val lr = new LogisticRegressionWithSGD().setIntercept(true)
     lr.optimizer.
-      setStepSize(10.0).
+      setStepSize(1.0).
       setNumIterations(10).
       setRegParam(1.0)
 
     val model = lr.run(testRDD, initialWeights)
 
     // Test the weights
-    assert(model.weights(0) ~== -430000.0 relTol 20000.0)
-    assert(model.intercept ~== 370000.0 relTol 20000.0)
+    // With regularization, the resulting weights will be smaller.
+    assert(model.weights(0) ~== -0.14 relTol 0.02)
+    assert(model.intercept ~== 0.25 relTol 0.02)
 
     val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
     val validationRDD = sc.parallelize(validationData, 2)