diff options
author | Travis Galoppo <tjg2107@columbia.edu> | 2014-12-31 15:39:58 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-12-31 15:39:58 -0800 |
commit | c4f0b4f334f7f3565375921fcac184ad5b1fb207 (patch) | |
tree | b8a5155efa5b255f577954f3328f660b2c5b2173 /mllib | |
parent | fdc2aa4918fd4c510f04812b782cc0bfef9a2107 (diff) | |
download | spark-c4f0b4f334f7f3565375921fcac184ad5b1fb207.tar.gz spark-c4f0b4f334f7f3565375921fcac184ad5b1fb207.tar.bz2 spark-c4f0b4f334f7f3565375921fcac184ad5b1fb207.zip |
SPARK-5020 [MLlib] GaussianMixtureModel.predictMembership() should take an RDD only
Removed unnecessary parameters to predictMembership()
CC: jkbradley
Author: Travis Galoppo <tjg2107@columbia.edu>
Closes #3854 from tgaloppo/spark-5020 and squashes the following commits:
1bf4669 [Travis Galoppo] renamed predictMembership() to predictSoft()
0f1d96e [Travis Galoppo] SPARK-5020 - Removed superfluous parameters from predictMembership()
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala | 9 |
1 file changed, 2 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 11a110db1f..b461ea4f0f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -45,7 +45,7 @@ class GaussianMixtureModel( /** Maps given points to their cluster indices. */ def predict(points: RDD[Vector]): RDD[Int] = { - val responsibilityMatrix = predictMembership(points, mu, sigma, weight, k) + val responsibilityMatrix = predictSoft(points) responsibilityMatrix.map(r => r.indexOf(r.max)) } @@ -53,12 +53,7 @@ class GaussianMixtureModel( * Given the input vectors, return the membership value of each vector * to all mixture components. */ - def predictMembership( - points: RDD[Vector], - mu: Array[Vector], - sigma: Array[Matrix], - weight: Array[Double], - k: Int): RDD[Array[Double]] = { + def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext val dists = sc.broadcast { (0 until k).map { i => |