aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author: Yanbo Liang <ybliang8@gmail.com> 2016-01-11 14:43:25 -0800
committer: Joseph K. Bradley <joseph@databricks.com> 2016-01-11 14:43:25 -0800
commit: ee4ee02b86be8756a6d895a2e23e80862134a6d3 (patch)
tree: 7d4402425ee3c9c10e38bd5bb85d9186cbde8fc9 /mllib
parent: a767ee8a0599f5482717493a3298413c65d8ff89 (diff)
download: spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.tar.gz
spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.tar.bz2
spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.zip
[SPARK-12603][MLLIB] PySpark MLlib GaussianMixtureModel should support single instance predict/predictSoft
PySpark MLlib ```GaussianMixtureModel``` should support single instance ```predict/predictSoft``` just like Scala does. Author: Yanbo Liang <ybliang8@gmail.com> Closes #10552 from yanboliang/spark-12603.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala | 2
2 files changed, 5 insertions, 1 deletion
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
index 6a3b20c88d..a689b09341 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
@@ -40,5 +40,9 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
SerDe.dumps(JavaConverters.seqAsJavaListConverter(modelGaussians).asJava)
}
+ def predictSoft(point: Vector): Vector = {
+ Vectors.dense(model.predictSoft(point))
+ }
+
def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 16bc45bcb6..42fe27024f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -75,7 +75,7 @@ class GaussianMixtureModel @Since("1.3.0") (
*/
@Since("1.5.0")
def predict(point: Vector): Int = {
- val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
+ val r = predictSoft(point)
r.indexOf(r.max)
}