aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorcoderxiang <shuoxiangpub@gmail.com>2014-10-27 19:43:39 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-27 19:43:39 -0700
commit7e3a1ada86e6adf1ddd4d8a321824daf5f3b2c75 (patch)
tree28e56f2ef6007b5310f00fc23f6d4188eb331a31 /mllib
parentbfa614b12795f1cfce4de0950f90cb8c4f2a7d53 (diff)
downloadspark-7e3a1ada86e6adf1ddd4d8a321824daf5f3b2c75.tar.gz
spark-7e3a1ada86e6adf1ddd4d8a321824daf5f3b2c75.tar.bz2
spark-7e3a1ada86e6adf1ddd4d8a321824daf5f3b2c75.zip
[MLlib] SPARK-3987: add test case on objective value for NNLS
Also update the step parameter to pass the proposed test. Author: coderxiang <shuoxiangpub@gmail.com>. Closes #2965 from coderxiang/nnls-test and squashes the following commits: 24b06f9 [coderxiang] add test case on objective value for NNLS; update step parameter to pass the test.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala30
2 files changed, 31 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
index e4b436b023..fef062e02b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
@@ -79,7 +79,7 @@ private[mllib] object NNLS {
// stopping condition
def stop(step: Double, ndir: Double, nx: Double): Boolean = {
((step.isNaN) // NaN
- || (step < 1e-6) // too small or negative
+ || (step < 1e-7) // too small or negative
|| (step > 1e40) // too large; almost certainly numerical problems
|| (ndir < 1e-12 * nx) // gradient relatively too small
|| (ndir < 1e-32) // gradient absolutely too small; numerical issues may lurk
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
index b781a6aed9..82c327bd49 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
@@ -37,6 +37,12 @@ class NNLSSuite extends FunSuite {
(ata, atb)
}
+ /** Compute the objective value 0.5 * x^T * ata * x - atb . x for a candidate solution x. */
+ def computeObjectiveValue(ata: DoubleMatrix, atb: DoubleMatrix, x: DoubleMatrix): Double = {
+ val res = (x.transpose().mmul(ata).mmul(x)).mul(0.5).sub(atb.dot(x))
+ res.get(0)
+ }
+
test("NNLS: exact solution cases") {
val n = 20
val rand = new Random(12346)
@@ -79,4 +85,28 @@ class NNLSSuite extends FunSuite {
assert(x(i) >= 0)
}
}
+
+ test("NNLS: objective value test") {
+ val n = 5
+ val ata = new DoubleMatrix(5, 5
+ , 517399.13534, 242529.67289, -153644.98976, 130802.84503, -798452.29283
+ , 242529.67289, 126017.69765, -75944.21743, 81785.36128, -405290.60884
+ , -153644.98976, -75944.21743, 46986.44577, -45401.12659, 247059.51049
+ , 130802.84503, 81785.36128, -45401.12659, 67457.31310, -253747.03819
+ , -798452.29283, -405290.60884, 247059.51049, -253747.03819, 1310939.40814
+ )
+ val atb = new DoubleMatrix(5, 1,
+ -31755.05710, 13047.14813, -20191.24443, 25993.77580, 11963.55017)
+
+ // Reference solution obtained from the MATLAB function quadprog.
+ val refx = new DoubleMatrix(Array(34.90751, 103.96254, 0.00000, 27.82094, 58.79627))
+ val refObj = computeObjectiveValue(ata, atb, refx)
+
+
+ val ws = NNLS.createWorkspace(n)
+ val x = new DoubleMatrix(NNLS.solve(ata, atb, ws))
+ val obj = computeObjectiveValue(ata, atb, x)
+
+ assert(obj < refObj + 1E-5) // the NNLS objective must stay below the reference objective plus a 1E-5 tolerance
+ }
}