aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTravis Galoppo <tjg2107@columbia.edu>2014-12-29 15:29:15 -0800
committerXiangrui Meng <meng@databricks.com>2014-12-29 15:29:15 -0800
commit6cf6fdf3ff5d1cf33c2dc28f039adc4d7c0f0464 (patch)
treea9d9bcd5af2c93f1bb89cb63edc60278ba4124c2
parent9bc0df6804f241aff24520d9c6ec54d9b11f5785 (diff)
downloadspark-6cf6fdf3ff5d1cf33c2dc28f039adc4d7c0f0464.tar.gz
spark-6cf6fdf3ff5d1cf33c2dc28f039adc4d7c0f0464.tar.bz2
spark-6cf6fdf3ff5d1cf33c2dc28f039adc4d7c0f0464.zip
SPARK-4156 [MLLIB] EM algorithm for GMMs
Implementation of Expectation-Maximization for Gaussian Mixture Models. This is my maiden contribution to Apache Spark, so I apologize now if I have done anything incorrectly; having said that, this work is my own, and I offer it to the project under the project's open source license. Author: Travis Galoppo <tjg2107@columbia.edu> Author: Travis Galoppo <travis@localhost.localdomain> Author: tgaloppo <tjg2107@columbia.edu> Author: FlytxtRnD <meethu.mathew@flytxt.com> Closes #3022 from tgaloppo/master and squashes the following commits: aaa8f25 [Travis Galoppo] MLUtils: changed privacy of EPSILON from [util] to [mllib] 709e4bf [Travis Galoppo] fixed usage line to include optional maxIterations parameter acf1fba [Travis Galoppo] Fixed parameter comment in GaussianMixtureModel Made maximum iterations an optional parameter to DenseGmmEM 9b2fc2a [Travis Galoppo] Style improvements Changed ExpectationSum to a private class b97fe00 [Travis Galoppo] Minor fixes and tweaks. 1de73f3 [Travis Galoppo] Removed redundant array from array creation 578c2d1 [Travis Galoppo] Removed unused import 227ad66 [Travis Galoppo] Moved prediction methods into model class. 308c8ad [Travis Galoppo] Numerous changes to improve code cff73e0 [Travis Galoppo] Replaced accumulators with RDD.aggregate 20ebca1 [Travis Galoppo] Removed unusued code 42b2142 [Travis Galoppo] Added functionality to allow setting of GMM starting point. Added two cluster test to testing suite. 8b633f3 [Travis Galoppo] Style issue 9be2534 [Travis Galoppo] Style issue d695034 [Travis Galoppo] Fixed style issues c3b8ce0 [Travis Galoppo] Merge branch 'master' of https://github.com/tgaloppo/spark Adds predict() method 2df336b [Travis Galoppo] Fixed style issue b99ecc4 [tgaloppo] Merge pull request #1 from FlytxtRnD/predictBranch f407b4c [FlytxtRnD] Added predict() to return the cluster labels and membership values 97044cf [Travis Galoppo] Fixed style issues dc9c742 [Travis Galoppo] Moved MultivariateGaussian utility class e7d413b [Travis Galoppo] Moved multivariate Gaussian utility class to mllib/stat/impl Improved comments 9770261 [Travis Galoppo] Corrected a variety of style and naming issues. 8aaa17d [Travis Galoppo] Added additional train() method to companion object for cluster count and tolerance parameters. 676e523 [Travis Galoppo] Fixed to no longer ignore delta value provided on command line e6ea805 [Travis Galoppo] Merged with master branch; update test suite with latest context changes. Improved cluster initialization strategy. 86fb382 [Travis Galoppo] Merge remote-tracking branch 'upstream/master' 719d8cc [Travis Galoppo] Added scala test suite with basic test c1a8e16 [Travis Galoppo] Made GaussianMixtureModel class serializable Modified sum function for better performance 5c96c57 [Travis Galoppo] Merge remote-tracking branch 'upstream/master' c15405c [Travis Galoppo] SPARK-4156
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala67
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala241
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala91
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala39
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala78
6 files changed, 517 insertions, 1 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
new file mode 100644
index 0000000000..948c350953
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.clustering.GaussianMixtureEM
+import org.apache.spark.mllib.linalg.Vectors
+
+/**
+ * An example Gaussian Mixture Model EM app. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.DenseGmmEM <input> <k> <covergenceTol>
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DenseGmmEM {
+ def main(args: Array[String]): Unit = {
+ if (args.length < 3) {
+ println("usage: DenseGmmEM <input file> <k> <convergenceTol> [maxIterations]")
+ } else {
+ val maxIterations = if (args.length > 3) args(3).toInt else 100
+ run(args(0), args(1).toInt, args(2).toDouble, maxIterations)
+ }
+ }
+
+ private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) {
+ val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example")
+ val ctx = new SparkContext(conf)
+
+ val data = ctx.textFile(inputFile).map { line =>
+ Vectors.dense(line.trim.split(' ').map(_.toDouble))
+ }.cache()
+
+ val clusters = new GaussianMixtureEM()
+ .setK(k)
+ .setConvergenceTol(convergenceTol)
+ .setMaxIterations(maxIterations)
+ .run(data)
+
+ for (i <- 0 until clusters.k) {
+ println("weight=%f\nmu=%s\nsigma=\n%s\n" format
+ (clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
+ }
+
+ println("Cluster labels (first <= 100):")
+ val clusterLabels = clusters.predict(data)
+ clusterLabels.take(100).foreach { x =>
+ print(" " + x)
+ }
+ println()
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
new file mode 100644
index 0000000000..bdf984aee4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import scala.collection.mutable.IndexedSeq
+
+import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
+import org.apache.spark.mllib.stat.impl.MultivariateGaussian
+import org.apache.spark.mllib.util.MLUtils
+
+/**
+ * This class performs expectation maximization for multivariate Gaussian
+ * Mixture Models (GMMs). A GMM represents a composite distribution of
+ * independent Gaussian distributions with associated "mixing" weights
+ * specifying each's contribution to the composite.
+ *
+ * Given a set of sample points, this class will maximize the log-likelihood
+ * for a mixture of k Gaussians, iterating until the log-likelihood changes by
+ * less than convergenceTol, or until it has reached the max number of iterations.
+ * While this process is generally guaranteed to converge, it is not guaranteed
+ * to find a global optimum.
+ *
+ * @param k The number of independent Gaussians in the mixture model
+ * @param convergenceTol The maximum change in log-likelihood at which convergence
+ * is considered to have occurred.
+ * @param maxIterations The maximum number of iterations to perform
+ */
+class GaussianMixtureEM private (
+ private var k: Int,
+ private var convergenceTol: Double,
+ private var maxIterations: Int) extends Serializable {
+
+ /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
+ def this() = this(2, 0.01, 100)
+
+ // number of samples per cluster to use when initializing Gaussians
+ private val nSamples = 5
+
+ // an initializing GMM can be provided rather than using the
+ // default random starting point
+ private var initialModel: Option[GaussianMixtureModel] = None
+
+ /** Set the initial GMM starting point, bypassing the random initialization.
+ * You must call setK() prior to calling this method, and the condition
+ * (model.k == this.k) must be met; failure will result in an IllegalArgumentException
+ */
+ def setInitialModel(model: GaussianMixtureModel): this.type = {
+ if (model.k == k) {
+ initialModel = Some(model)
+ } else {
+ throw new IllegalArgumentException("mismatched cluster count (model.k != k)")
+ }
+ this
+ }
+
+ /** Return the user supplied initial GMM, if supplied */
+ def getInitialModel: Option[GaussianMixtureModel] = initialModel
+
+ /** Set the number of Gaussians in the mixture model. Default: 2 */
+ def setK(k: Int): this.type = {
+ this.k = k
+ this
+ }
+
+ /** Return the number of Gaussians in the mixture model */
+ def getK: Int = k
+
+ /** Set the maximum number of iterations to run. Default: 100 */
+ def setMaxIterations(maxIterations: Int): this.type = {
+ this.maxIterations = maxIterations
+ this
+ }
+
+ /** Return the maximum number of iterations to run */
+ def getMaxIterations: Int = maxIterations
+
+ /**
+ * Set the largest change in log-likelihood at which convergence is
+ * considered to have occurred.
+ */
+ def setConvergenceTol(convergenceTol: Double): this.type = {
+ this.convergenceTol = convergenceTol
+ this
+ }
+
+ /** Return the largest change in log-likelihood at which convergence is
+ * considered to have occurred.
+ */
+ def getConvergenceTol: Double = convergenceTol
+
+ /** Perform expectation maximization */
+ def run(data: RDD[Vector]): GaussianMixtureModel = {
+ val sc = data.sparkContext
+
+ // we will operate on the data as breeze data
+ val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
+
+ // Get length of the input vectors
+ val d = breezeData.first.length
+
+ // Determine initial weights and corresponding Gaussians.
+ // If the user supplied an initial GMM, we use those values, otherwise
+ // we start with uniform weights, a random mean from the data, and
+ // diagonal covariance matrices using component variances
+ // derived from the samples
+ val (weights, gaussians) = initialModel match {
+ case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) =>
+ new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix)
+ })
+
+ case None => {
+ val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
+ (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
+ val slice = samples.view(i * nSamples, (i + 1) * nSamples)
+ new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
+ })
+ }
+ }
+
+ var llh = Double.MinValue // current log-likelihood
+ var llhp = 0.0 // previous log-likelihood
+
+ var iter = 0
+ while(iter < maxIterations && Math.abs(llh-llhp) > convergenceTol) {
+ // create and broadcast curried cluster contribution function
+ val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_)
+
+ // aggregate the cluster contribution for all sample points
+ val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _)
+
+ // Create new distributions based on the partial assignments
+ // (often referred to as the "M" step in literature)
+ val sumWeights = sums.weights.sum
+ var i = 0
+ while (i < k) {
+ val mu = sums.means(i) / sums.weights(i)
+ val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr
+ weights(i) = sums.weights(i) / sumWeights
+ gaussians(i) = new MultivariateGaussian(mu, sigma)
+ i = i + 1
+ }
+
+ llhp = llh // current becomes previous
+ llh = sums.logLikelihood // this is the freshly computed log-likelihood
+ iter += 1
+ }
+
+ // Need to convert the breeze matrices to MLlib matrices
+ val means = Array.tabulate(k) { i => Vectors.fromBreeze(gaussians(i).mu) }
+ val sigmas = Array.tabulate(k) { i => Matrices.fromBreeze(gaussians(i).sigma) }
+ new GaussianMixtureModel(weights, means, sigmas)
+ }
+
+ /** Average of dense breeze vectors */
+ private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = {
+ val v = BreezeVector.zeros[Double](x(0).length)
+ x.foreach(xi => v += xi)
+ v / x.length.toDouble
+ }
+
+ /**
+ * Construct matrix where diagonal entries are element-wise
+ * variance of input vectors (computes biased variance)
+ */
+ private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = {
+ val mu = vectorMean(x)
+ val ss = BreezeVector.zeros[Double](x(0).length)
+ x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u)
+ diag(ss / x.length.toDouble)
+ }
+}
+
+// companion class to provide zero constructor for ExpectationSum
+private object ExpectationSum {
+ def zero(k: Int, d: Int): ExpectationSum = {
+ new ExpectationSum(0.0, Array.fill(k)(0.0),
+ Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
+ }
+
+ // compute cluster contributions for each input point
+ // (U, T) => U for aggregation
+ def add(
+ weights: Array[Double],
+ dists: Array[MultivariateGaussian])
+ (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = {
+ val p = weights.zip(dists).map {
+ case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x)
+ }
+ val pSum = p.sum
+ sums.logLikelihood += math.log(pSum)
+ val xxt = x * new Transpose(x)
+ var i = 0
+ while (i < sums.k) {
+ p(i) /= pSum
+ sums.weights(i) += p(i)
+ sums.means(i) += x * p(i)
+ sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr
+ i = i + 1
+ }
+ sums
+ }
+}
+
+// Aggregation class for partial expectation results
+private class ExpectationSum(
+ var logLikelihood: Double,
+ val weights: Array[Double],
+ val means: Array[BreezeVector[Double]],
+ val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {
+
+ val k = weights.length
+
+ def +=(x: ExpectationSum): ExpectationSum = {
+ var i = 0
+ while (i < k) {
+ weights(i) += x.weights(i)
+ means(i) += x.means(i)
+ sigmas(i) += x.sigmas(i)
+ i = i + 1
+ }
+ logLikelihood += x.logLikelihood
+ this
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
new file mode 100644
index 0000000000..11a110db1f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import breeze.linalg.{DenseVector => BreezeVector}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.{Matrix, Vector}
+import org.apache.spark.mllib.stat.impl.MultivariateGaussian
+import org.apache.spark.mllib.util.MLUtils
+
+/**
+ * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
+ * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are
+ * the respective mean and covariance for each Gaussian distribution i=1..k.
+ *
+ * @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is
+ * the weight for Gaussian i, and weight.sum == 1
+ * @param mu Means for each Gaussian in the mixture, where mu(i) is the mean for Gaussian i
+ * @param sigma Covariance maxtrix for each Gaussian in the mixture, where sigma(i) is the
+ * covariance matrix for Gaussian i
+ */
+class GaussianMixtureModel(
+ val weight: Array[Double],
+ val mu: Array[Vector],
+ val sigma: Array[Matrix]) extends Serializable {
+
+ /** Number of gaussians in mixture */
+ def k: Int = weight.length
+
+ /** Maps given points to their cluster indices. */
+ def predict(points: RDD[Vector]): RDD[Int] = {
+ val responsibilityMatrix = predictMembership(points, mu, sigma, weight, k)
+ responsibilityMatrix.map(r => r.indexOf(r.max))
+ }
+
+ /**
+ * Given the input vectors, return the membership value of each vector
+ * to all mixture components.
+ */
+ def predictMembership(
+ points: RDD[Vector],
+ mu: Array[Vector],
+ sigma: Array[Matrix],
+ weight: Array[Double],
+ k: Int): RDD[Array[Double]] = {
+ val sc = points.sparkContext
+ val dists = sc.broadcast {
+ (0 until k).map { i =>
+ new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix)
+ }.toArray
+ }
+ val weights = sc.broadcast(weight)
+ points.map { x =>
+ computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k)
+ }
+ }
+
+ /**
+ * Compute the partial assignments for each vector
+ */
+ private def computeSoftAssignments(
+ pt: BreezeVector[Double],
+ dists: Array[MultivariateGaussian],
+ weights: Array[Double],
+ k: Int): Array[Double] = {
+ val p = weights.zip(dists).map {
+ case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt)
+ }
+ val pSum = p.sum
+ for (i <- 0 until k) {
+ p(i) /= pSum
+ }
+ p
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala
new file mode 100644
index 0000000000..2eab5d2778
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat.impl
+
+import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, Transpose, det, pinv}
+
+/**
+ * Utility class to implement the density function for multivariate Gaussian distribution.
+ * Breeze provides this functionality, but it requires the Apache Commons Math library,
+ * so this class is here so-as to not introduce a new dependency in Spark.
+ */
+private[mllib] class MultivariateGaussian(
+ val mu: DBV[Double],
+ val sigma: DBM[Double]) extends Serializable {
+ private val sigmaInv2 = pinv(sigma) * -0.5
+ private val U = math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(det(sigma), -0.5)
+
+ /** Returns density of this multivariate Gaussian at given point, x */
+ def pdf(x: DBV[Double]): Double = {
+ val delta = x - mu
+ val deltaTranspose = new Transpose(delta)
+ U * math.exp(deltaTranspose * sigmaInv2 * delta)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index b0d05ae33e..1d07b5dab8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -39,7 +39,7 @@ import org.apache.spark.streaming.dstream.DStream
*/
object MLUtils {
- private[util] lazy val EPSILON = {
+ private[mllib] lazy val EPSILON = {
var eps = 1.0
while ((1.0 + (eps / 2.0)) != 1.0) {
eps /= 2.0
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
new file mode 100644
index 0000000000..23feb82874
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{Vectors, Matrices}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContext {
+ test("single cluster") {
+ val data = sc.parallelize(Array(
+ Vectors.dense(6.0, 9.0),
+ Vectors.dense(5.0, 10.0),
+ Vectors.dense(4.0, 11.0)
+ ))
+
+ // expectations
+ val Ew = 1.0
+ val Emu = Vectors.dense(5.0, 10.0)
+ val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0))
+
+ val gmm = new GaussianMixtureEM().setK(1).run(data)
+
+ assert(gmm.weight(0) ~== Ew absTol 1E-5)
+ assert(gmm.mu(0) ~== Emu absTol 1E-5)
+ assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+ }
+
+ test("two clusters") {
+ val data = sc.parallelize(Array(
+ Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
+ Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
+ Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
+ Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
+ Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
+ ))
+
+ // we set an initial gaussian to induce expected results
+ val initialGmm = new GaussianMixtureModel(
+ Array(0.5, 0.5),
+ Array(Vectors.dense(-1.0), Vectors.dense(1.0)),
+ Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0)))
+ )
+
+ val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
+ val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
+ val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
+
+ val gmm = new GaussianMixtureEM()
+ .setK(2)
+ .setInitialModel(initialGmm)
+ .run(data)
+
+ assert(gmm.weight(0) ~== Ew(0) absTol 1E-3)
+ assert(gmm.weight(1) ~== Ew(1) absTol 1E-3)
+ assert(gmm.mu(0) ~== Emu(0) absTol 1E-3)
+ assert(gmm.mu(1) ~== Emu(1) absTol 1E-3)
+ assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3)
+ assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3)
+ }
+}