author     Travis Galoppo <tjg2107@columbia.edu>  2014-12-31 15:39:58 -0800
committer  Xiangrui Meng <meng@databricks.com>    2014-12-31 15:39:58 -0800
commit     c4f0b4f334f7f3565375921fcac184ad5b1fb207 (patch)
tree       b8a5155efa5b255f577954f3328f660b2c5b2173
parent     fdc2aa4918fd4c510f04812b782cc0bfef9a2107 (diff)
SPARK-5020 [MLlib] GaussianMixtureModel.predictMembership() should take an RDD only
Removed unnecessary parameters to predictMembership().

CC: jkbradley

Author: Travis Galoppo <tjg2107@columbia.edu>

Closes #3854 from tgaloppo/spark-5020 and squashes the following commits:

1bf4669 [Travis Galoppo] renamed predictMembership() to predictSoft()
0f1d96e [Travis Galoppo] SPARK-5020 - Removed superfluous parameters from predictMembership()
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala | 9
1 file changed, 2 insertions(+), 7 deletions(-)
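
For context, a minimal caller-side sketch of what this change means for users of the model. This example is not part of the commit: sc is assumed to be an active SparkContext and model an already-fitted GaussianMixtureModel; only the predictSoft() call reflects the patched signature shown in the diff below.

    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.rdd.RDD

    // Hypothetical input data; sc and model are assumed to already exist.
    val points: RDD[Vector] = sc.parallelize(Seq(
      Vectors.dense(1.0, 2.0),
      Vectors.dense(5.0, 6.0)))

    // Before this patch the caller had to hand the model its own parameters, e.g.
    //   model.predictMembership(points, mu, sigma, weight, k)
    // After the patch the model supplies them itself:
    val soft: RDD[Array[Double]] = model.predictSoft(points)  // membership value per mixture component
    val hard: RDD[Int] = model.predict(points)                // index of the largest membership value
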
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 11a110db1f..b461ea4f0f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -45,7 +45,7 @@ class GaussianMixtureModel(
   /** Maps given points to their cluster indices. */
   def predict(points: RDD[Vector]): RDD[Int] = {
-    val responsibilityMatrix = predictMembership(points, mu, sigma, weight, k)
+    val responsibilityMatrix = predictSoft(points)
     responsibilityMatrix.map(r => r.indexOf(r.max))
   }
@@ -53,12 +53,7 @@ class GaussianMixtureModel(
    * Given the input vectors, return the membership value of each vector
    * to all mixture components.
    */
-  def predictMembership(
-      points: RDD[Vector],
-      mu: Array[Vector],
-      sigma: Array[Matrix],
-      weight: Array[Double],
-      k: Int): RDD[Array[Double]] = {
+  def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
     val sc = points.sparkContext
     val dists = sc.broadcast {
       (0 until k).map { i =>