diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-01-11 14:43:25 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-01-11 14:43:25 -0800 |
commit | ee4ee02b86be8756a6d895a2e23e80862134a6d3 (patch) | |
tree | 7d4402425ee3c9c10e38bd5bb85d9186cbde8fc9 /mllib | |
parent | a767ee8a0599f5482717493a3298413c65d8ff89 (diff) | |
download | spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.tar.gz spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.tar.bz2 spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.zip |
[SPARK-12603][MLLIB] PySpark MLlib GaussianMixtureModel should support single instance predict/predictSoft
PySpark MLlib ```GaussianMixtureModel``` should support single instance ```predict/predictSoft``` just like Scala do.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #10552 from yanboliang/spark-12603.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala | 4 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala | 2 |
2 files changed, 5 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala index 6a3b20c88d..a689b09341 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -40,5 +40,9 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { SerDe.dumps(JavaConverters.seqAsJavaListConverter(modelGaussians).asJava) } + def predictSoft(point: Vector): Vector = { + Vectors.dense(model.predictSoft(point)) + } + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 16bc45bcb6..42fe27024f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -75,7 +75,7 @@ class GaussianMixtureModel @Since("1.3.0") ( */ @Since("1.5.0") def predict(point: Vector): Int = { - val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + val r = predictSoft(point) r.indexOf(r.max) } |