aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-01-11 14:43:25 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-11 14:43:25 -0800
commitee4ee02b86be8756a6d895a2e23e80862134a6d3 (patch)
tree7d4402425ee3c9c10e38bd5bb85d9186cbde8fc9 /examples
parenta767ee8a0599f5482717493a3298413c65d8ff89 (diff)
downloadspark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.tar.gz
spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.tar.bz2
spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.zip
[SPARK-12603][MLLIB] PySpark MLlib GaussianMixtureModel should support single instance predict/predictSoft
PySpark MLlib ```GaussianMixtureModel``` should support single instance ```predict/predictSoft``` just like Scala do. Author: Yanbo Liang <ybliang8@gmail.com> Closes #10552 from yanboliang/spark-12603.
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/python/mllib/gaussian_mixture_model.py4
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala6
2 files changed, 10 insertions, 0 deletions
diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py
index 2cb8010cdc..69e836fc1d 100644
--- a/examples/src/main/python/mllib/gaussian_mixture_model.py
+++ b/examples/src/main/python/mllib/gaussian_mixture_model.py
@@ -62,5 +62,9 @@ if __name__ == "__main__":
for i in range(args.k):
print(("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu,
"sigma = ", model.gaussians[i].sigma.toArray()))
+ print("\n")
+ print(("The membership value of each vector to all mixture components (first 100): ",
+ model.predictSoft(data).take(100)))
+ print("\n")
print(("Cluster labels (first 100): ", model.predict(data).take(100)))
sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
index 1fce4ba7ef..90b817b23e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
@@ -58,6 +58,12 @@ object DenseGaussianMixture {
(clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
}
+ println("The membership value of each vector to all mixture components (first <= 100):")
+ val membership = clusters.predictSoft(data)
+ membership.take(100).foreach { x =>
+ print(" " + x.mkString(","))
+ }
+ println()
println("Cluster labels (first <= 100):")
val clusterLabels = clusters.predict(data)
clusterLabels.take(100).foreach { x =>