aboutsummaryrefslogtreecommitdiff
path: root/mllib-local
diff options
context:
space:
mode:
author: Joseph K. Bradley <joseph@databricks.com> 2016-04-26 16:53:16 -0700
committer: DB Tsai <dbt@netflix.com> 2016-04-26 16:53:16 -0700
commit: bd2c9a6d48ef6d489c747d9db2642bdef6b1f728 (patch)
tree: 9a8a4864825aca4e8f11d4442d33e1ca4f7ac0c4 /mllib-local
parent: 0c99c23b7d9f0c3538cd2b062d551411712a2bcc (diff)
download: spark-bd2c9a6d48ef6d489c747d9db2642bdef6b1f728.tar.gz
spark-bd2c9a6d48ef6d489c747d9db2642bdef6b1f728.tar.bz2
spark-bd2c9a6d48ef6d489c747d9db2642bdef6b1f728.zip
[SPARK-14732][ML] spark.ml GaussianMixture should use MultivariateGaussian in mllib-local
## What changes were proposed in this pull request?

Before, spark.ml GaussianMixtureModel used the spark.mllib MultivariateGaussian in its public API. This was added after 1.6, so we can modify this API without breaking APIs.

This PR copies MultivariateGaussian to mllib-local in spark.ml, with a few changes:
* Renamed fields to match numpy, scipy: mu => mean, sigma => cov

This PR then uses the spark.ml MultivariateGaussian in the spark.ml GaussianMixtureModel, which involves:
* Modifying the constructor
* Adding a computeProbabilities method

Also:
* Added EPSILON to mllib-local for use in MultivariateGaussian

## How was this patch tested?

Existing unit tests

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12593 from jkbradley/sparkml-gmm-fix.
Diffstat (limited to 'mllib-local')
-rw-r--r-- mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala 30
-rw-r--r-- mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala 131
-rw-r--r-- mllib-local/src/test/scala/org/apache/spark/ml/impl/UtilsSuite.scala 30
-rw-r--r-- mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala 83
4 files changed, 274 insertions, 0 deletions
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala b/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
new file mode 100644
index 0000000000..112de982e4
--- /dev/null
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl
+
+
+private[ml] object Utils {
+
+  /**
+   * Machine epsilon for Double, found by halving from 1.0 until `1.0 + eps / 2.0`
+   * is no longer distinguishable from `1.0`. Evaluated once, on first access.
+   */
+  lazy val EPSILON = Iterator.iterate(1.0)(_ / 2.0)
+    .dropWhile(eps => (1.0 + (eps / 2.0)) != 1.0)
+    .next()
+}
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
new file mode 100644
index 0000000000..c62a1eab20
--- /dev/null
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat.distribution
+
+import breeze.linalg.{diag, eigSym, max, DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
+
+import org.apache.spark.ml.impl.Utils
+import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors}
+
+
+/**
+ * This class provides basic functionality for a Multivariate Gaussian (Normal) Distribution. In
+ * the event that the covariance matrix is singular, the density will be computed in a
+ * reduced dimensional subspace under which the distribution is supported.
+ * (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]])
+ *
+ * @param mean The mean vector of the distribution
+ * @param cov The covariance matrix of the distribution
+ * @throws IllegalArgumentException if `cov` is not square, if its size does not match `mean`,
+ *                                  or if none of its eigenvalues exceeds the numerical tolerance
+ */
+class MultivariateGaussian(
+    val mean: Vector,
+    val cov: Matrix) extends Serializable {
+
+  require(cov.numCols == cov.numRows, "Covariance matrix must be square")
+  require(mean.size == cov.numCols, "Mean vector length must match covariance matrix size")
+
+  /** Private constructor taking Breeze types */
+  private[ml] def this(mean: BDV[Double], cov: BDM[Double]) = {
+    this(Vectors.fromBreeze(mean), Matrices.fromBreeze(cov))
+  }
+
+  /** The mean as a dense Breeze vector; subtracted from each point in `logpdf`. */
+  private val breezeMu = mean.toBreeze.toDenseVector
+
+  /**
+   * Compute distribution dependent constants:
+   *    rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t
+   *    u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
+   */
+  private val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants
+
+  /** Returns density of this multivariate Gaussian at given point, x */
+  def pdf(x: Vector): Double = {
+    pdf(x.toBreeze)
+  }
+
+  /** Returns the log-density of this multivariate Gaussian at given point, x */
+  def logpdf(x: Vector): Double = {
+    logpdf(x.toBreeze)
+  }
+
+  /** Returns density of this multivariate Gaussian at given point, x */
+  private[ml] def pdf(x: BV[Double]): Double = {
+    math.exp(logpdf(x))
+  }
+
+  /** Returns the log-density of this multivariate Gaussian at given point, x */
+  private[ml] def logpdf(x: BV[Double]): Double = {
+    val delta = x - breezeMu
+    val v = rootSigmaInv * delta
+    u + v.t * v * -0.5
+  }
+
+  /**
+   * Calculate distribution dependent components used for the density function:
+   *    pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu))
+   * where k is length of the mean vector.
+   *
+   * We here compute distribution-fixed parts
+   *    log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
+   * and D^(-1/2)^ * U.t, where sigma = U * D * U.t
+   *
+   * Both the determinant and the inverse can be computed from the singular value decomposition
+   * of sigma.  Noting that covariance matrices are always symmetric and positive semi-definite,
+   * we can use the eigendecomposition.  We also do not compute the inverse directly; noting
+   * that
+   *
+   *    sigma = U * D * U.t
+   *    inv(sigma) = U * inv(D) * U.t
+   *               = (D^(-1/2)^ * U.t).t * (D^(-1/2)^ * U.t)
+   *
+   * and thus
+   *
+   *    -0.5 * (x-mu).t * inv(sigma) * (x-mu) = -0.5 * norm(D^(-1/2)^ * U.t * (x-mu))^2^
+   *
+   * To guard against singular covariance matrices, this method computes both the
+   * pseudo-determinant and the pseudo-inverse (Moore-Penrose).  Singular values are considered
+   * to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and
+   * relation to the maximum singular value (same tolerance used by, e.g., Octave).
+   *
+   * @throws IllegalArgumentException if no eigenvalue of the covariance exceeds the tolerance
+   */
+  private def calculateCovarianceConstants: (BDM[Double], Double) = {
+    val eigSym.EigSym(d, u) = eigSym(cov.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t
+
+    // For numerical stability, values are considered to be non-zero only if they exceed tol.
+    // This prevents any inverted value from exceeding (eps * n * max(d))^-1
+    val tol = Utils.EPSILON * max(d) * d.length
+
+    // Fail fast: without any eigenvalue above tol the pseudo-inverse below would be all zeros
+    if (!d.activeValuesIterator.exists(_ > tol)) {
+      throw new IllegalArgumentException("Covariance matrix has no non-zero singular values")
+    }
+
+    // log(pseudo-determinant) is sum of the logs of all non-zero singular values
+    val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum
+
+    // calculate the root-pseudo-inverse of the diagonal matrix of singular values
+    // by inverting the square root of all non-zero values
+    val pinvS = diag(new BDV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray))
+
+    (pinvS * u.t, -0.5 * (mean.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
+  }
+}
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/impl/UtilsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/impl/UtilsSuite.scala
new file mode 100644
index 0000000000..44b122b694
--- /dev/null
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/impl/UtilsSuite.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl
+
+import org.apache.spark.ml.impl.Utils.EPSILON
+import org.apache.spark.ml.SparkMLFunSuite
+
+
+class UtilsSuite extends SparkMLFunSuite {
+ // Checks that adding EPSILON to 1.0 changes the value, while adding EPSILON / 2.0 does not.
+ test("EPSILON") {
+ assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.")
+ assert(1.0 + EPSILON / 2.0 === 1.0, s"EPSILON is too big: $EPSILON.")
+ }
+}
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala
new file mode 100644
index 0000000000..f9306ed83e
--- /dev/null
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat.distribution
+
+import org.apache.spark.ml.SparkMLFunSuite
+import org.apache.spark.ml.linalg.{Matrices, Vectors}
+import org.apache.spark.ml.util.TestingUtils._
+
+
+class MultivariateGaussianSuite extends SparkMLFunSuite {
+ // Compares pdf values of MultivariateGaussian against hard-coded expected densities.
+ test("univariate") { // 1-D case: variance 1.0 and variance 4.0, evaluated at 0.0 and 1.5
+ val x1 = Vectors.dense(0.0)
+ val x2 = Vectors.dense(1.5)
+
+ val mu = Vectors.dense(0.0)
+ val sigma1 = Matrices.dense(1, 1, Array(1.0))
+ val dist1 = new MultivariateGaussian(mu, sigma1)
+ assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5)
+ assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5)
+
+ val sigma2 = Matrices.dense(1, 1, Array(4.0))
+ val dist2 = new MultivariateGaussian(mu, sigma2)
+ assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5)
+ assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5)
+ }
+
+ test("multivariate") { // 2-D case: identity covariance, then one with non-zero off-diagonals
+ val x1 = Vectors.dense(0.0, 0.0)
+ val x2 = Vectors.dense(1.0, 1.0)
+
+ val mu = Vectors.dense(0.0, 0.0)
+ val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
+ val dist1 = new MultivariateGaussian(mu, sigma1)
+ assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5)
+ assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5)
+
+ val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
+ val dist2 = new MultivariateGaussian(mu, sigma2)
+ assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5)
+ assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5)
+ }
+
+ test("multivariate degenerate") { // singular covariance: all four entries are 1.0 (rank 1)
+ val x1 = Vectors.dense(0.0, 0.0)
+ val x2 = Vectors.dense(1.0, 1.0)
+
+ val mu = Vectors.dense(0.0, 0.0)
+ val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
+ val dist = new MultivariateGaussian(mu, sigma)
+ assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5)
+ assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5)
+ }
+
+ test("SPARK-11302") { // 4-D case; presumably a regression test for that JIRA issue - confirm
+ val x = Vectors.dense(629, 640, 1.7188, 618.19)
+ val mu = Vectors.dense(
+ 1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697)
+ val sigma = Matrices.dense(4, 4, Array(
+ 166769.00466698944, 169336.6705268059, 12.820670788921873, 164243.93314092053,
+ 169336.6705268059, 172041.5670061245, 21.62590020524533, 166678.01075856484,
+ 12.820670788921873, 21.62590020524533, 0.872524191943962, 4.283255814732373,
+ 164243.93314092053, 166678.01075856484, 4.283255814732373, 161848.9196719207))
+ val dist = new MultivariateGaussian(mu, sigma)
+ // Agrees with R's dmvnorm: 7.154782e-05
+ assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9)
+ }
+}