diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-01-11 14:43:25 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-01-11 14:43:25 -0800 |
commit | ee4ee02b86be8756a6d895a2e23e80862134a6d3 (patch) | |
tree | 7d4402425ee3c9c10e38bd5bb85d9186cbde8fc9 /examples/src/main | |
parent | a767ee8a0599f5482717493a3298413c65d8ff89 (diff) | |
download | spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.tar.gz spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.tar.bz2 spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.zip |
[SPARK-12603][MLLIB] PySpark MLlib GaussianMixtureModel should support single instance predict/predictSoft
PySpark MLlib ```GaussianMixtureModel``` should support single instance ```predict/predictSoft``` just like Scala do.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #10552 from yanboliang/spark-12603.
Diffstat (limited to 'examples/src/main')
-rw-r--r-- | examples/src/main/python/mllib/gaussian_mixture_model.py | 4 | ||||
-rw-r--r-- | examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala | 6 |
2 files changed, 10 insertions, 0 deletions
diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py index 2cb8010cdc..69e836fc1d 100644 --- a/examples/src/main/python/mllib/gaussian_mixture_model.py +++ b/examples/src/main/python/mllib/gaussian_mixture_model.py @@ -62,5 +62,9 @@ if __name__ == "__main__": for i in range(args.k): print(("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu, "sigma = ", model.gaussians[i].sigma.toArray())) + print("\n") + print(("The membership value of each vector to all mixture components (first 100): ", + model.predictSoft(data).take(100))) + print("\n") print(("Cluster labels (first 100): ", model.predict(data).take(100))) sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index 1fce4ba7ef..90b817b23e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -58,6 +58,12 @@ object DenseGaussianMixture { (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) } + println("The membership value of each vector to all mixture components (first <= 100):") + val membership = clusters.predictSoft(data) + membership.take(100).foreach { x => + print(" " + x.mkString(",")) + } + println() println("Cluster labels (first <= 100):") val clusterLabels = clusters.predict(data) clusterLabels.take(100).foreach { x => |