aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala9
1 files changed, 2 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 11a110db1f..b461ea4f0f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -45,7 +45,7 @@ class GaussianMixtureModel(
/** Maps given points to their cluster indices. */
def predict(points: RDD[Vector]): RDD[Int] = {
- val responsibilityMatrix = predictMembership(points, mu, sigma, weight, k)
+ val responsibilityMatrix = predictSoft(points)
responsibilityMatrix.map(r => r.indexOf(r.max))
}
@@ -53,12 +53,7 @@ class GaussianMixtureModel(
* Given the input vectors, return the membership value of each vector
* to all mixture components.
*/
- def predictMembership(
- points: RDD[Vector],
- mu: Array[Vector],
- sigma: Array[Matrix],
- weight: Array[Double],
- k: Int): RDD[Array[Double]] = {
+ def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
val sc = points.sparkContext
val dists = sc.broadcast {
(0 until k).map { i =>