From 7e3a1ada86e6adf1ddd4d8a321824daf5f3b2c75 Mon Sep 17 00:00:00 2001 From: coderxiang Date: Mon, 27 Oct 2014 19:43:39 -0700 Subject: [MLlib] SPARK-3987: add test case on objective value for NNLS Also update step parameter to pass the proposed test Author: coderxiang Closes #2965 from coderxiang/nnls-test and squashes the following commits: 24b06f9 [coderxiang] add test case on objective value for NNLS; update step parameter to pass the test --- .../org/apache/spark/mllib/optimization/NNLS.scala | 2 +- .../spark/mllib/optimization/NNLSSuite.scala | 30 ++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala index e4b436b023..fef062e02b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala @@ -79,7 +79,7 @@ private[mllib] object NNLS { // stopping condition def stop(step: Double, ndir: Double, nx: Double): Boolean = { ((step.isNaN) // NaN - || (step < 1e-6) // too small or negative + || (step < 1e-7) // too small or negative || (step > 1e40) // too small; almost certainly numerical problems || (ndir < 1e-12 * nx) // gradient relatively too small || (ndir < 1e-32) // gradient absolutely too small; numerical issues may lurk diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index b781a6aed9..82c327bd49 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -37,6 +37,12 @@ class NNLSSuite extends FunSuite { (ata, atb) } + /** Compute the objective value */ + def computeObjectiveValue(ata: DoubleMatrix, atb: DoubleMatrix, x: DoubleMatrix): Double = { + val res = (x.transpose().mmul(ata).mmul(x)).mul(0.5).sub(atb.dot(x)) + res.get(0) + } + test("NNLS: exact solution cases") { val n = 20 val rand = new Random(12346) @@ -79,4 +85,28 @@ class NNLSSuite extends FunSuite { assert(x(i) >= 0) } } + + test("NNLS: objective value test") { + val n = 5 + val ata = new DoubleMatrix(5, 5 + , 517399.13534, 242529.67289, -153644.98976, 130802.84503, -798452.29283 + , 242529.67289, 126017.69765, -75944.21743, 81785.36128, -405290.60884 + , -153644.98976, -75944.21743, 46986.44577, -45401.12659, 247059.51049 + , 130802.84503, 81785.36128, -45401.12659, 67457.31310, -253747.03819 + , -798452.29283, -405290.60884, 247059.51049, -253747.03819, 1310939.40814 + ) + val atb = new DoubleMatrix(5, 1, + -31755.05710, 13047.14813, -20191.24443, 25993.77580, 11963.55017) + + /** reference solution obtained from matlab function quadprog */ + val refx = new DoubleMatrix(Array(34.90751, 103.96254, 0.00000, 27.82094, 58.79627)) + val refObj = computeObjectiveValue(ata, atb, refx) + + + val ws = NNLS.createWorkspace(n) + val x = new DoubleMatrix(NNLS.solve(ata, atb, ws)) + val obj = computeObjectiveValue(ata, atb, x) + + assert(obj < refObj + 1E-5) + } } -- cgit v1.2.3