aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
author: FlytxtRnD <meethu.mathew@flytxt.com> 2015-02-02 23:04:55 -0800
committer: Xiangrui Meng <meng@databricks.com> 2015-02-02 23:04:55 -0800
commit: 50a1a874e1d087a6c79835b1936d0009622a97b1 (patch)
tree: 81381fdb41d6bf9e3cbf59291f200fbc5ddab3d1 /examples
parent: c31c36c4a76bd3449696383321332ec95bff7fed (diff)
download: spark-50a1a874e1d087a6c79835b1936d0009622a97b1.tar.gz
spark-50a1a874e1d087a6c79835b1936d0009622a97b1.tar.bz2
spark-50a1a874e1d087a6c79835b1936d0009622a97b1.zip
[SPARK-5012][MLLib][PySpark]Python API for Gaussian Mixture Model
Python API for the Gaussian Mixture Model clustering algorithm in MLLib.

Author: FlytxtRnD <meethu.mathew@flytxt.com>

Closes #4059 from FlytxtRnD/PythonGmmWrapper and squashes the following commits:

c973ab3 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
339b09c [FlytxtRnD] Added MultivariateGaussian namedtuple and Arraybuffer in trainGaussianMixture
fa0a142 [FlytxtRnD] New line added
d5b36ab [FlytxtRnD] Changed argument names to lowercase
ac134f1 [FlytxtRnD] Merge branch 'PythonGmmWrapper' of https://github.com/FlytxtRnD/spark into PythonGmmWrapper
6671ea1 [FlytxtRnD] Added mllib/stat/distribution.py
3aee84b [FlytxtRnD] Fixed style issues
2e9f12a [FlytxtRnD] Added mllib/stat/distribution.py and fixed style issues
b22532c [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
2e14d82 [FlytxtRnD] Incorporate MultivariateGaussian instances in GaussianMixtureModel
05767c7 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
3464d19 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
c1d4c71 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'origin/PythonGmmWrapper' into PythonGmmWrapper
426d130 [FlytxtRnD] Added random seed parameter
332bad1 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
f82750b [FlytxtRnD] Fixed style issues
5c83825 [FlytxtRnD] Split input file with space delimiter
fda60f3 [FlytxtRnD] Python API for Gaussian Mixture Model
Diffstat (limited to 'examples')
-rw-r--r-- examples/src/main/python/mllib/gaussian_mixture_model.py 65
1 files changed, 65 insertions, 0 deletions
diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py
new file mode 100644
index 0000000000..a2cd626c9f
--- /dev/null
+++ b/examples/src/main/python/mllib/gaussian_mixture_model.py
@@ -0,0 +1,65 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A Gaussian Mixture Model clustering program using MLlib.
+"""
+import sys
+import random
+import argparse
+import numpy as np
+
+from pyspark import SparkConf, SparkContext
+from pyspark.mllib.clustering import GaussianMixture
+
+
+def parseVector(line):
+    """
+    Convert one line of text into a NumPy vector of floats.
+
+    The line is split on any run of whitespace, so repeated, leading or
+    trailing spaces (and tabs) between the numbers are tolerated.
+
+    :param line: String holding whitespace-separated numeric values
+    :return: 1-D ``numpy`` array with one float per value on the line
+    :raises ValueError: If the line is blank or a value is not a valid float
+    """
+    tokens = line.split()
+    if not tokens:
+        # Fail at parse time rather than handing an empty vector to training.
+        raise ValueError("empty input line: expected whitespace-separated numbers")
+    return np.array([float(x) for x in tokens])
+
+
+if __name__ == "__main__":
+    # Command-line driver: parse the arguments, train a Gaussian mixture on the
+    # points read from ``inputFile`` and print the resulting parameters.
+    #
+    # Parameters
+    # ----------
+    # inputFile:        Input file path which contains data points
+    # k:                Number of mixture components
+    # --convergenceTol: Convergence threshold. Defaults to 1e-3
+    # --maxIterations:  Number of EM iterations to perform. Defaults to 100
+    # --seed:           Random seed. Defaults to a random 19-bit integer
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('inputFile', help='Input File')
+    parser.add_argument('k', type=int, help='Number of clusters')
+    parser.add_argument('--convergenceTol', default=1e-3, type=float, help='convergence threshold')
+    parser.add_argument('--maxIterations', default=100, type=int, help='Number of iterations')
+    # ``int`` instead of ``long``: the name ``long`` does not exist on Python 3,
+    # while on Python 2 ``int()`` already promotes large values to ``long``.
+    parser.add_argument('--seed', default=random.getrandbits(19),
+                        type=int, help='Random seed')
+    args = parser.parse_args()
+
+    conf = SparkConf().setAppName("GMM")
+    sc = SparkContext(conf=conf)
+    try:
+        lines = sc.textFile(args.inputFile)
+        data = lines.map(parseVector)
+        model = GaussianMixture.train(data, args.k, args.convergenceTol,
+                                      args.maxIterations, args.seed)
+        # One line per mixture component: its weight, mean and covariance.
+        for i in range(args.k):
+            print ("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu,
+                   "sigma = ", model.gaussians[i].sigma.toArray())
+        print ("Cluster labels (first 100): ", model.predict(data).take(100))
+    finally:
+        # Stop the SparkContext even when loading, training or prediction fails.
+        sc.stop()