#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
A Gaussian Mixture Model clustering program using MLlib.
"""
from __future__ import print_function

import random
import argparse
import numpy as np

from pyspark import SparkConf, SparkContext
from pyspark.mllib.clustering import GaussianMixture


def parseVector(line):
    """Parse one line of whitespace-separated numbers into a NumPy array.

    :param line: Text line holding the coordinates of a single data point.
    :return: 1-D ``numpy.ndarray`` of floats, one entry per token.
    :raises ValueError: If any token cannot be converted to ``float``.
    """
    # split() with no argument tolerates leading/trailing whitespace and
    # runs of separators; split(' ') would yield empty tokens there and
    # float('') raises ValueError.
    return np.array([float(x) for x in line.split()])


def main():
    """
    Train a Gaussian Mixture Model on the input file and print the result.

    Parameters
    ----------
    :param inputFile:        Input file path which contains data points
    :param k:                Number of mixture components
    :param convergenceTol:   Convergence threshold. Default to 1e-3
    :param maxIterations:    Number of EM iterations to perform. Default to 100
    :param seed:             Random seed
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('inputFile', help='Input File')
    parser.add_argument('k', type=int, help='Number of clusters')
    parser.add_argument('--convergenceTol', default=1e-3, type=float,
                        help='convergence threshold')
    parser.add_argument('--maxIterations', default=100, type=int,
                        help='Number of iterations')
    # `int` instead of `long`: `long` does not exist on Python 3 (NameError),
    # while on Python 2 int() already promotes large values to long.
    parser.add_argument('--seed', default=random.getrandbits(19), type=int,
                        help='Random seed')
    args = parser.parse_args()

    conf = SparkConf().setAppName("GMM")
    sc = SparkContext(conf=conf)
    try:
        lines = sc.textFile(args.inputFile)
        data = lines.map(parseVector)
        model = GaussianMixture.train(data, args.k, args.convergenceTol,
                                      args.maxIterations, args.seed)
        # Each report is printed as a single tuple, exactly as before, so the
        # script's output format is unchanged.
        for i in range(args.k):
            print(("weight = ", model.weights[i],
                   "mu = ", model.gaussians[i].mu,
                   "sigma = ", model.gaussians[i].sigma.toArray()))
        print(("Cluster labels (first 100): ", model.predict(data).take(100)))
    finally:
        # Always release the SparkContext, even when training or prediction
        # fails part-way through.
        sc.stop()


if __name__ == "__main__":
    main()