aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/python/mllib/gaussian_mixture_model.py
blob: a2cd626c9f19d9859738f50fbf42c280287cef12 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
A Gaussian Mixture Model clustering program using MLlib.
"""
import sys
import random
import argparse
import numpy as np

from pyspark import SparkConf, SparkContext
from pyspark.mllib.clustering import GaussianMixture


def parseVector(line):
    """Parse one line of whitespace-separated numbers into a NumPy vector.

    ``str.split()`` (no separator) is used so that leading/trailing
    whitespace and runs of spaces or tabs between values do not produce
    empty tokens; with ``split(' ')`` such tokens reach ``float('')`` and
    raise ``ValueError``.

    :param line: One text line, e.g. ``"1.0 2.5 -3"``.
    :return: 1-D ``numpy.ndarray`` of floats, one entry per token.
    """
    return np.array([float(x) for x in line.split()])


if __name__ == "__main__":
    # Command-line parameters
    # -----------------------
    # inputFile:         Input file path; each line holds one data point as
    #                    whitespace-separated numbers (see parseVector).
    # k:                 Number of mixture components.
    # --convergenceTol:  Convergence threshold. Defaults to 1e-3.
    # --maxIterations:   Number of EM iterations to perform. Defaults to 100.
    # --seed:            Random seed. Defaults to a random 19-bit integer.

    parser = argparse.ArgumentParser()
    parser.add_argument('inputFile', help='Input File')
    parser.add_argument('k', type=int, help='Number of clusters')
    parser.add_argument('--convergenceTol', default=1e-3, type=float, help='convergence threshold')
    parser.add_argument('--maxIterations', default=100, type=int, help='Number of iterations')
    # `long` is not a builtin on Python 3 (NameError). `int` parses
    # arbitrarily large integers on Python 3 and auto-promotes to `long`
    # on Python 2, so it works on both.
    parser.add_argument('--seed', default=random.getrandbits(19),
                        type=int, help='Random seed')
    args = parser.parse_args()

    conf = SparkConf().setAppName("GMM")
    sc = SparkContext(conf=conf)
    try:
        lines = sc.textFile(args.inputFile)
        data = lines.map(parseVector)
        model = GaussianMixture.train(data, args.k,
                                      convergenceTol=args.convergenceTol,
                                      maxIterations=args.maxIterations,
                                      seed=args.seed)
        # One line per mixture component: its weight, mean and covariance.
        # Single-argument print calls behave the same on Python 2 and 3
        # (a multi-argument `print (...)` prints a tuple on Python 2).
        for weight, gaussian in zip(model.weights, model.gaussians):
            print("weight = {} mu = {} sigma = {}".format(
                weight, gaussian.mu, gaussian.sigma.toArray()))
        print("Cluster labels (first 100): {}".format(
            model.predict(data).take(100)))
    finally:
        # Release the SparkContext even if training or prediction raises.
        sc.stop()