aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTravis Galoppo <tjg2107@columbia.edu>2015-01-20 12:58:11 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-20 12:58:11 -0800
commit23e25543beaa5966b5f07365f338ce338fd6d71f (patch)
tree8c0642e5e895a02b0d97f1a2f5ca45142751fab1
parent769aced9e7f058f5008ce405f7c9714c3db203be (diff)
downloadspark-23e25543beaa5966b5f07365f338ce338fd6d71f.tar.gz
spark-23e25543beaa5966b5f07365f338ce338fd6d71f.tar.bz2
spark-23e25543beaa5966b5f07365f338ce338fd6d71f.zip
SPARK-5019 [MLlib] - GaussianMixtureModel exposes instances of MultivariateGauss...
This PR modifies GaussianMixtureModel to expose instances of MutlivariateGaussian rather than separate mean and covariance arrays. Author: Travis Galoppo <tjg2107@columbia.edu> Closes #4088 from tgaloppo/spark-5019 and squashes the following commits: 3ef6c7f [Travis Galoppo] In GaussianMixtureModel: Changed name of weight, gaussian to weights, gaussians. Other sources modified accordingly. 091e8da [Travis Galoppo] SPARK-5019 - GaussianMixtureModel exposes instances of MultivariateGaussian rather than mean/covariance matrices
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala9
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala21
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala25
4 files changed, 26 insertions, 31 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
index 948c350953..de58be38c7 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
@@ -54,7 +54,7 @@ object DenseGmmEM {
for (i <- 0 until clusters.k) {
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
- (clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
+ (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
}
println("Cluster labels (first <= 100):")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
index d8e1346194..899fe5e9e9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
@@ -134,9 +134,7 @@ class GaussianMixtureEM private (
// diagonal covariance matrices using component variances
// derived from the samples
val (weights, gaussians) = initialModel match {
- case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) =>
- new MultivariateGaussian(mu, sigma)
- })
+ case Some(gmm) => (gmm.weights, gmm.gaussians)
case None => {
val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
@@ -176,10 +174,7 @@ class GaussianMixtureEM private (
iter += 1
}
- // Need to convert the breeze matrices to MLlib matrices
- val means = Array.tabulate(k) { i => gaussians(i).mu }
- val sigmas = Array.tabulate(k) { i => gaussians(i).sigma }
- new GaussianMixtureModel(weights, means, sigmas)
+ new GaussianMixtureModel(weights, gaussians)
}
/** Average of dense breeze vectors */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 416cad080c..1a2178ee7f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseVector => BreezeVector}
import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.linalg.{Matrix, Vector}
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
@@ -36,12 +36,13 @@ import org.apache.spark.mllib.util.MLUtils
* covariance matrix for Gaussian i
*/
class GaussianMixtureModel(
- val weight: Array[Double],
- val mu: Array[Vector],
- val sigma: Array[Matrix]) extends Serializable {
+ val weights: Array[Double],
+ val gaussians: Array[MultivariateGaussian]) extends Serializable {
+
+ require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
/** Number of gaussians in mixture */
- def k: Int = weight.length
+ def k: Int = weights.length
/** Maps given points to their cluster indices. */
def predict(points: RDD[Vector]): RDD[Int] = {
@@ -55,14 +56,10 @@ class GaussianMixtureModel(
*/
def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
val sc = points.sparkContext
- val dists = sc.broadcast {
- (0 until k).map { i =>
- new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix)
- }.toArray
- }
- val weights = sc.broadcast(weight)
+ val bcDists = sc.broadcast(gaussians)
+ val bcWeights = sc.broadcast(weights)
points.map { x =>
- computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k)
+ computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
index 9da5495741..198997b5bb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.{Vectors, Matrices}
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -39,9 +40,9 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
val seeds = Array(314589, 29032897, 50181, 494821, 4660)
seeds.foreach { seed =>
val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data)
- assert(gmm.weight(0) ~== Ew absTol 1E-5)
- assert(gmm.mu(0) ~== Emu absTol 1E-5)
- assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+ assert(gmm.weights(0) ~== Ew absTol 1E-5)
+ assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5)
+ assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5)
}
}
@@ -57,8 +58,10 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
Array(0.5, 0.5),
- Array(Vectors.dense(-1.0), Vectors.dense(1.0)),
- Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0)))
+ Array(
+ new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))),
+ new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))
+ )
)
val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
@@ -70,11 +73,11 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
.setInitialModel(initialGmm)
.run(data)
- assert(gmm.weight(0) ~== Ew(0) absTol 1E-3)
- assert(gmm.weight(1) ~== Ew(1) absTol 1E-3)
- assert(gmm.mu(0) ~== Emu(0) absTol 1E-3)
- assert(gmm.mu(1) ~== Emu(1) absTol 1E-3)
- assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3)
- assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3)
+ assert(gmm.weights(0) ~== Ew(0) absTol 1E-3)
+ assert(gmm.weights(1) ~== Ew(1) absTol 1E-3)
+ assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3)
+ assert(gmm.gaussians(1).mu ~== Emu(1) absTol 1E-3)
+ assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
+ assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
}
}