diff options
author | Sean Owen <sowen@cloudera.com> | 2015-10-27 23:07:37 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-10-27 23:07:37 -0700 |
commit | 826e1e304b57abbc56b8b7ffd663d53942ab3c7c (patch) | |
tree | 379cecd7931154b2ce835302106139f06af613be /mllib/src/test/scala/org | |
parent | d9c6039897236c3f1e4503aa95c5c9b07b32eadd (diff) | |
download | spark-826e1e304b57abbc56b8b7ffd663d53942ab3c7c.tar.gz spark-826e1e304b57abbc56b8b7ffd663d53942ab3c7c.tar.bz2 spark-826e1e304b57abbc56b8b7ffd663d53942ab3c7c.zip |
[SPARK-11302][MLLIB] 2) Multivariate Gaussian Model with Covariance matrix returns incorrect answer in some cases
Fix computation of root-sigma-inverse in multivariate Gaussian; add a test and fix related Python mixture model test.
Supersedes https://github.com/apache/spark/pull/9293
Author: Sean Owen <sowen@cloudera.com>
Closes #9309 from srowen/SPARK-11302.2.
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index aa60deb665..6e7a003475 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -65,4 +65,19 @@ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5) assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5) } + + test("SPARK-11302") { + val x = Vectors.dense(629, 640, 1.7188, 618.19) + val mu = Vectors.dense( + 1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697) + val sigma = Matrices.dense(4, 4, Array( + 166769.00466698944, 169336.6705268059, 12.820670788921873, 164243.93314092053, + 169336.6705268059, 172041.5670061245, 21.62590020524533, 166678.01075856484, + 12.820670788921873, 21.62590020524533, 0.872524191943962, 4.283255814732373, + 164243.93314092053, 166678.01075856484, 4.283255814732373, 161848.9196719207)) + val dist = new MultivariateGaussian(mu, sigma) + // Agrees with R's dmvnorm: 7.154782e-05 + assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9) + } + } |