diff options
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala | 9 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala | 21 |
2 files changed, 11 insertions, 19 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala index d8e1346194..899fe5e9e9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala @@ -134,9 +134,7 @@ class GaussianMixtureEM private ( // diagonal covariance matrices using component variances // derived from the samples val (weights, gaussians) = initialModel match { - case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) => - new MultivariateGaussian(mu, sigma) - }) + case Some(gmm) => (gmm.weights, gmm.gaussians) case None => { val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) @@ -176,10 +174,7 @@ class GaussianMixtureEM private ( iter += 1 } - // Need to convert the breeze matrices to MLlib matrices - val means = Array.tabulate(k) { i => gaussians(i).mu } - val sigmas = Array.tabulate(k) { i => gaussians(i).sigma } - new GaussianMixtureModel(weights, means, sigmas) + new GaussianMixtureModel(weights, gaussians) } /** Average of dense breeze vectors */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 416cad080c..1a2178ee7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils @@ -36,12 +36,13 @@ import org.apache.spark.mllib.util.MLUtils * covariance matrix for Gaussian i */ class GaussianMixtureModel( - val weight: Array[Double], - val mu: Array[Vector], - val sigma: Array[Matrix]) extends Serializable { + val weights: Array[Double], + val gaussians: Array[MultivariateGaussian]) extends Serializable { + + require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") /** Number of gaussians in mixture */ - def k: Int = weight.length + def k: Int = weights.length /** Maps given points to their cluster indices. */ def predict(points: RDD[Vector]): RDD[Int] = { @@ -55,14 +56,10 @@ class GaussianMixtureModel( */ def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext - val dists = sc.broadcast { - (0 until k).map { i => - new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix) - }.toArray - } - val weights = sc.broadcast(weight) + val bcDists = sc.broadcast(gaussians) + val bcWeights = sc.broadcast(weights) points.map { x => - computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k) + computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k) } } |