aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorTravis Galoppo <tjg2107@columbia.edu>2015-01-11 21:31:16 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-11 21:31:16 -0800
commit2130de9d8f50f52b9b2d296b377df81d840546b3 (patch)
tree07b9c8fcce82d15690ad7cba68ea024fa409980c /mllib/src/test
parentf38ef6586c2980183c983b2aa14a5ddc1856b7b7 (diff)
downloadspark-2130de9d8f50f52b9b2d296b377df81d840546b3.tar.gz
spark-2130de9d8f50f52b9b2d296b377df81d840546b3.tar.bz2
spark-2130de9d8f50f52b9b2d296b377df81d840546b3.zip
SPARK-5018 [MLlib] [WIP] Make MultivariateGaussian public
Moving MutlivariateGaussian from private[mllib] to public. The class uses Breeze vectors internally, so this involves creating a public interface using MLlib vectors and matrices. This initial commit provides public construction, accessors for mean/covariance, density and log-density. Other potential methods include entropy and sample generation. Author: Travis Galoppo <tjg2107@columbia.edu> Closes #3923 from tgaloppo/spark-5018 and squashes the following commits: 2b15587 [Travis Galoppo] Style correction b4121b4 [Travis Galoppo] Merge remote-tracking branch 'upstream/master' into spark-5018 e30a100 [Travis Galoppo] Made mu, sigma private[mllib] members of MultivariateGaussian Moved MultivariateGaussian (and test suite) from stat.impl to stat.distribution (required updates in GaussianMixture{EM,Model}.scala) Marked MultivariateGaussian as @DeveloperApi Fixed style error 9fa3bb7 [Travis Galoppo] Style improvements 91a5fae [Travis Galoppo] Rearranged equation for part of density function 8c35381 [Travis Galoppo] Fixed accessor methods to match member variable names. Modified calculations to avoid log(pow(x,y)) calculations 0943dc4 [Travis Galoppo] SPARK-5018 4dee9e1 [Travis Galoppo] SPARK-5018
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala (renamed from mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala)33
1 files changed, 16 insertions, 17 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
index d58f2587e5..fac2498e4d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
@@ -15,54 +15,53 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.stat.impl
+package org.apache.spark.mllib.stat.distribution
import org.scalatest.FunSuite
-import breeze.linalg.{ DenseVector => BDV, DenseMatrix => BDM }
-
+import org.apache.spark.mllib.linalg.{ Vectors, Matrices }
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext {
test("univariate") {
- val x1 = new BDV(Array(0.0))
- val x2 = new BDV(Array(1.5))
+ val x1 = Vectors.dense(0.0)
+ val x2 = Vectors.dense(1.5)
- val mu = new BDV(Array(0.0))
- val sigma1 = new BDM(1, 1, Array(1.0))
+ val mu = Vectors.dense(0.0)
+ val sigma1 = Matrices.dense(1, 1, Array(1.0))
val dist1 = new MultivariateGaussian(mu, sigma1)
assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5)
assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5)
- val sigma2 = new BDM(1, 1, Array(4.0))
+ val sigma2 = Matrices.dense(1, 1, Array(4.0))
val dist2 = new MultivariateGaussian(mu, sigma2)
assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5)
assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5)
}
test("multivariate") {
- val x1 = new BDV(Array(0.0, 0.0))
- val x2 = new BDV(Array(1.0, 1.0))
+ val x1 = Vectors.dense(0.0, 0.0)
+ val x2 = Vectors.dense(1.0, 1.0)
- val mu = new BDV(Array(0.0, 0.0))
- val sigma1 = new BDM(2, 2, Array(1.0, 0.0, 0.0, 1.0))
+ val mu = Vectors.dense(0.0, 0.0)
+ val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
val dist1 = new MultivariateGaussian(mu, sigma1)
assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5)
assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5)
- val sigma2 = new BDM(2, 2, Array(4.0, -1.0, -1.0, 2.0))
+ val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
val dist2 = new MultivariateGaussian(mu, sigma2)
assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5)
assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5)
}
test("multivariate degenerate") {
- val x1 = new BDV(Array(0.0, 0.0))
- val x2 = new BDV(Array(1.0, 1.0))
+ val x1 = Vectors.dense(0.0, 0.0)
+ val x2 = Vectors.dense(1.0, 1.0)
- val mu = new BDV(Array(0.0, 0.0))
- val sigma = new BDM(2, 2, Array(1.0, 1.0, 1.0, 1.0))
+ val mu = Vectors.dense(0.0, 0.0)
+ val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
val dist = new MultivariateGaussian(mu, sigma)
assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5)
assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5)