From 50a1a874e1d087a6c79835b1936d0009622a97b1 Mon Sep 17 00:00:00 2001 From: FlytxtRnD Date: Mon, 2 Feb 2015 23:04:55 -0800 Subject: [SPARK-5012][MLLib][PySpark]Python API for Gaussian Mixture Model Python API for the Gaussian Mixture Model clustering algorithm in MLLib. Author: FlytxtRnD Closes #4059 from FlytxtRnD/PythonGmmWrapper and squashes the following commits: c973ab3 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper 339b09c [FlytxtRnD] Added MultivariateGaussian namedtuple and Arraybuffer in trainGaussianMixture fa0a142 [FlytxtRnD] New line added d5b36ab [FlytxtRnD] Changed argument names to lowercase ac134f1 [FlytxtRnD] Merge branch 'PythonGmmWrapper' of https://github.com/FlytxtRnD/spark into PythonGmmWrapper 6671ea1 [FlytxtRnD] Added mllib/stat/distribution.py 3aee84b [FlytxtRnD] Fixed style issues 2e9f12a [FlytxtRnD] Added mllib/stat/distribution.py and fixed style issues b22532c [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper 2e14d82 [FlytxtRnD] Incorporate MultivariateGaussian instances in GaussianMixtureModel 05767c7 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper 3464d19 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper c1d4c71 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'origin/PythonGmmWrapper' into PythonGmmWrapper 426d130 [FlytxtRnD] Added random seed parameter 332bad1 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper f82750b [FlytxtRnD] Fixed style issues 5c83825 [FlytxtRnD] Split input file with space delimiter fda60f3 [FlytxtRnD] Python API for Gaussian Mixture Model --- .../main/python/mllib/gaussian_mixture_model.py | 65 ++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 examples/src/main/python/mllib/gaussian_mixture_model.py (limited to 'examples') diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py new file mode 100644 index 0000000000..a2cd626c9f --- /dev/null +++ b/examples/src/main/python/mllib/gaussian_mixture_model.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A Gaussian Mixture Model clustering program using MLlib. +""" +import sys +import random +import argparse +import numpy as np + +from pyspark import SparkConf, SparkContext +from pyspark.mllib.clustering import GaussianMixture + + +def parseVector(line): + return np.array([float(x) for x in line.split(' ')]) + + +if __name__ == "__main__": + """ + Parameters + ---------- + :param inputFile: Input file path which contains data points + :param k: Number of mixture components + :param convergenceTol: Convergence threshold. Default to 1e-3 + :param maxIterations: Number of EM iterations to perform. Default to 100 + :param seed: Random seed + """ + + parser = argparse.ArgumentParser() + parser.add_argument('inputFile', help='Input File') + parser.add_argument('k', type=int, help='Number of clusters') + parser.add_argument('--convergenceTol', default=1e-3, type=float, help='convergence threshold') + parser.add_argument('--maxIterations', default=100, type=int, help='Number of iterations') + parser.add_argument('--seed', default=random.getrandbits(19), + type=long, help='Random seed') + args = parser.parse_args() + + conf = SparkConf().setAppName("GMM") + sc = SparkContext(conf=conf) + + lines = sc.textFile(args.inputFile) + data = lines.map(parseVector) + model = GaussianMixture.train(data, args.k, args.convergenceTol, + args.maxIterations, args.seed) + for i in range(args.k): + print ("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu, + "sigma = ", model.gaussians[i].sigma.toArray()) + print ("Cluster labels (first 100): ", model.predict(data).take(100)) + sc.stop() -- cgit v1.2.3