From 50a1a874e1d087a6c79835b1936d0009622a97b1 Mon Sep 17 00:00:00 2001 From: FlytxtRnD Date: Mon, 2 Feb 2015 23:04:55 -0800 Subject: [SPARK-5012][MLLib][PySpark]Python API for Gaussian Mixture Model Python API for the Gaussian Mixture Model clustering algorithm in MLLib. Author: FlytxtRnD Closes #4059 from FlytxtRnD/PythonGmmWrapper and squashes the following commits: c973ab3 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper 339b09c [FlytxtRnD] Added MultivariateGaussian namedtuple and Arraybuffer in trainGaussianMixture fa0a142 [FlytxtRnD] New line added d5b36ab [FlytxtRnD] Changed argument names to lowercase ac134f1 [FlytxtRnD] Merge branch 'PythonGmmWrapper' of https://github.com/FlytxtRnD/spark into PythonGmmWrapper 6671ea1 [FlytxtRnD] Added mllib/stat/distribution.py 3aee84b [FlytxtRnD] Fixed style issues 2e9f12a [FlytxtRnD] Added mllib/stat/distribution.py and fixed style issues b22532c [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper 2e14d82 [FlytxtRnD] Incorporate MultivariateGaussian instances in GaussianMixtureModel 05767c7 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper 3464d19 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper c1d4c71 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'origin/PythonGmmWrapper' into PythonGmmWrapper 426d130 [FlytxtRnD] Added random seed parameter 332bad1 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper f82750b [FlytxtRnD] Fixed style issues 5c83825 [FlytxtRnD] Split input file with space delimiter fda60f3 [FlytxtRnD] Python API for Gaussian Mixture Model --- python/pyspark/mllib/tests.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) (limited to 'python/pyspark/mllib/tests.py') diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 61e0cf5d90..42aa228737 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -167,6 +167,32 @@ class ListTests(PySparkTestCase): # TODO: Allow small numeric difference. self.assertTrue(array_equal(c1, c2)) + def test_gmm(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + [1, 2], + [8, 9], + [-4, -3], + [-6, -7], + ]) + clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=100, seed=56) + labels = clusters.predict(data).collect() + self.assertEquals(labels[0], labels[1]) + self.assertEquals(labels[2], labels[3]) + + def test_gmm_deterministic(self): + from pyspark.mllib.clustering import GaussianMixture + x = range(0, 100, 10) + y = range(0, 100, 10) + data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) + clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=100, seed=63) + clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=100, seed=63) + for c1, c2 in zip(clusters1.weights, clusters2.weights): + self.assertEquals(round(c1, 7), round(c2, 7)) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees -- cgit v1.2.3