diff options
author | FlytxtRnD <meethu.mathew@flytxt.com> | 2015-02-02 23:04:55 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-02-02 23:04:55 -0800 |
commit | 50a1a874e1d087a6c79835b1936d0009622a97b1 (patch) | |
tree | 81381fdb41d6bf9e3cbf59291f200fbc5ddab3d1 /python/pyspark/mllib/tests.py | |
parent | c31c36c4a76bd3449696383321332ec95bff7fed (diff) | |
download | spark-50a1a874e1d087a6c79835b1936d0009622a97b1.tar.gz spark-50a1a874e1d087a6c79835b1936d0009622a97b1.tar.bz2 spark-50a1a874e1d087a6c79835b1936d0009622a97b1.zip |
[SPARK-5012][MLLib][PySpark]Python API for Gaussian Mixture Model
Python API for the Gaussian Mixture Model clustering algorithm in MLLib.
Author: FlytxtRnD <meethu.mathew@flytxt.com>
Closes #4059 from FlytxtRnD/PythonGmmWrapper and squashes the following commits:
c973ab3 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
339b09c [FlytxtRnD] Added MultivariateGaussian namedtuple and Arraybuffer in trainGaussianMixture
fa0a142 [FlytxtRnD] New line added
d5b36ab [FlytxtRnD] Changed argument names to lowercase
ac134f1 [FlytxtRnD] Merge branch 'PythonGmmWrapper' of https://github.com/FlytxtRnD/spark into PythonGmmWrapper
6671ea1 [FlytxtRnD] Added mllib/stat/distribution.py
3aee84b [FlytxtRnD] Fixed style issues
2e9f12a [FlytxtRnD] Added mllib/stat/distribution.py and fixed style issues
b22532c [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
2e14d82 [FlytxtRnD] Incorporate MultivariateGaussian instances in GaussianMixtureModel
05767c7 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
3464d19 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
c1d4c71 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'origin/PythonGmmWrapper' into PythonGmmWrapper
426d130 [FlytxtRnD] Added random seed parameter
332bad1 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
f82750b [FlytxtRnD] Fixed style issues
5c83825 [FlytxtRnD] Split input file with space delimiter
fda60f3 [FlytxtRnD] Python API for Gaussian Mixture Model
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- | python/pyspark/mllib/tests.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 61e0cf5d90..42aa228737 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -167,6 +167,32 @@ class ListTests(PySparkTestCase): # TODO: Allow small numeric difference. self.assertTrue(array_equal(c1, c2)) + def test_gmm(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + [1, 2], + [8, 9], + [-4, -3], + [-6, -7], + ]) + clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=100, seed=56) + labels = clusters.predict(data).collect() + self.assertEquals(labels[0], labels[1]) + self.assertEquals(labels[2], labels[3]) + + def test_gmm_deterministic(self): + from pyspark.mllib.clustering import GaussianMixture + x = range(0, 100, 10) + y = range(0, 100, 10) + data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) + clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=100, seed=63) + clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=100, seed=63) + for c1, c2 in zip(clusters1.weights, clusters2.weights): + self.assertEquals(round(c1, 7), round(c2, 7)) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees |