about summary refs log tree commit diff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/python/mllib/gaussian_mixture_model.py65
1 file changed, 65 insertions, 0 deletions
diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py
new file mode 100644
index 0000000000..a2cd626c9f
--- /dev/null
+++ b/examples/src/main/python/mllib/gaussian_mixture_model.py
@@ -0,0 +1,65 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A Gaussian Mixture Model clustering program using MLlib.
+"""
+import sys
+import random
+import argparse
+import numpy as np
+
+from pyspark import SparkConf, SparkContext
+from pyspark.mllib.clustering import GaussianMixture
+
+
def parseVector(line):
    """Convert one space-separated line of text into a NumPy array of floats."""
    fields = line.split(' ')
    return np.array(list(map(float, fields)))
+
+
if __name__ == "__main__":
    # Command-line driver: parse arguments, train a Gaussian Mixture Model on
    # the space-separated input vectors, and print the fitted parameters.
    parser = argparse.ArgumentParser(
        description='Gaussian Mixture Model clustering using MLlib')
    parser.add_argument('inputFile',
                        help='Input file path: one space-separated data point per line')
    parser.add_argument('k', type=int, help='Number of mixture components (clusters)')
    parser.add_argument('--convergenceTol', default=1e-3, type=float,
                        help='Convergence threshold (default: 1e-3)')
    parser.add_argument('--maxIterations', default=100, type=int,
                        help='Number of EM iterations to perform (default: 100)')
    # Use `int`, not the Python-2-only `long` builtin, so the example also runs
    # on Python 3; Python 2 promotes int to long automatically when needed.
    parser.add_argument('--seed', default=random.getrandbits(19),
                        type=int, help='Random seed')
    args = parser.parse_args()

    conf = SparkConf().setAppName("GMM")
    sc = SparkContext(conf=conf)

    try:
        lines = sc.textFile(args.inputFile)
        data = lines.map(parseVector)
        model = GaussianMixture.train(data, args.k, args.convergenceTol,
                                      args.maxIterations, args.seed)
        # Report the fitted weight, mean, and covariance of each component.
        # Formatted strings (rather than `print (a, b, ...)`) print identically
        # on Python 2 and 3; on Python 2 the parenthesized form emits a tuple repr.
        for i in range(args.k):
            print("weight = %s mu = %s sigma = %s" %
                  (model.weights[i], model.gaussians[i].mu,
                   model.gaussians[i].sigma.toArray()))
        print("Cluster labels (first 100): %s" % model.predict(data).take(100))
    finally:
        # Always release the SparkContext, even if training fails.
        sc.stop()