aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala21
1 files changed, 9 insertions, 12 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 416cad080c..1a2178ee7f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseVector => BreezeVector}
import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.linalg.{Matrix, Vector}
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
@@ -36,12 +36,13 @@ import org.apache.spark.mllib.util.MLUtils
* covariance matrix for Gaussian i
*/
class GaussianMixtureModel(
- val weight: Array[Double],
- val mu: Array[Vector],
- val sigma: Array[Matrix]) extends Serializable {
+ val weights: Array[Double],
+ val gaussians: Array[MultivariateGaussian]) extends Serializable {
+
+ require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
/** Number of gaussians in mixture */
- def k: Int = weight.length
+ def k: Int = weights.length
/** Maps given points to their cluster indices. */
def predict(points: RDD[Vector]): RDD[Int] = {
@@ -55,14 +56,10 @@ class GaussianMixtureModel(
*/
def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
val sc = points.sparkContext
- val dists = sc.broadcast {
- (0 until k).map { i =>
- new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix)
- }.toArray
- }
- val weights = sc.broadcast(weight)
+ val bcDists = sc.broadcast(gaussians)
+ val bcWeights = sc.broadcast(weights)
points.map { x =>
- computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k)
+ computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k)
}
}