aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
author: FlytxtRnD <meethu.mathew@flytxt.com> 2015-02-02 23:04:55 -0800
committer: Xiangrui Meng <meng@databricks.com> 2015-02-02 23:04:55 -0800
commit: 50a1a874e1d087a6c79835b1936d0009622a97b1 (patch)
tree: 81381fdb41d6bf9e3cbf59291f200fbc5ddab3d1 /python/pyspark/mllib/tests.py
parent: c31c36c4a76bd3449696383321332ec95bff7fed (diff)
download: spark-50a1a874e1d087a6c79835b1936d0009622a97b1.tar.gz
spark-50a1a874e1d087a6c79835b1936d0009622a97b1.tar.bz2
spark-50a1a874e1d087a6c79835b1936d0009622a97b1.zip
[SPARK-5012][MLLib][PySpark] Python API for Gaussian Mixture Model
Python API for the Gaussian Mixture Model clustering algorithm in MLLib.

Author: FlytxtRnD <meethu.mathew@flytxt.com>

Closes #4059 from FlytxtRnD/PythonGmmWrapper and squashes the following commits:

c973ab3 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
339b09c [FlytxtRnD] Added MultivariateGaussian namedtuple and Arraybuffer in trainGaussianMixture
fa0a142 [FlytxtRnD] New line added
d5b36ab [FlytxtRnD] Changed argument names to lowercase
ac134f1 [FlytxtRnD] Merge branch 'PythonGmmWrapper' of https://github.com/FlytxtRnD/spark into PythonGmmWrapper
6671ea1 [FlytxtRnD] Added mllib/stat/distribution.py
3aee84b [FlytxtRnD] Fixed style issues
2e9f12a [FlytxtRnD] Added mllib/stat/distribution.py and fixed style issues
b22532c [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
2e14d82 [FlytxtRnD] Incorporate MultivariateGaussian instances in GaussianMixtureModel
05767c7 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
3464d19 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
c1d4c71 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'origin/PythonGmmWrapper' into PythonGmmWrapper
426d130 [FlytxtRnD] Added random seed parameter
332bad1 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
f82750b [FlytxtRnD] Fixed style issues
5c83825 [FlytxtRnD] Split input file with space delimiter
fda60f3 [FlytxtRnD] Python API for Gaussian Mixture Model
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- python/pyspark/mllib/tests.py 26
1 files changed, 26 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 61e0cf5d90..42aa228737 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -167,6 +167,32 @@ class ListTests(PySparkTestCase):
# TODO: Allow small numeric difference.
self.assertTrue(array_equal(c1, c2))
+ def test_gmm(self):
+     """Train a two-component Gaussian mixture and check label grouping.
+
+     Four 2-D points are parallelized, a model with k=2 is trained with a
+     fixed seed, and the predicted labels are collected. The test asserts
+     that the first two points share a label and that the last two points
+     share a label.
+     """
+     from pyspark.mllib.clustering import GaussianMixture
+     data = self.sc.parallelize([
+         [1, 2],
+         [8, 9],
+         [-4, -3],
+         [-6, -7],
+     ])
+     clusters = GaussianMixture.train(data, 2, convergenceTol=0.001,
+                                      maxIterations=100, seed=56)
+     labels = clusters.predict(data).collect()
+     # assertEqual is used instead of assertEquals: the latter is a
+     # deprecated unittest alias that newer Python versions no longer ship.
+     self.assertEqual(labels[0], labels[1])
+     self.assertEqual(labels[2], labels[3])
+
+ def test_gmm_deterministic(self):
+     """Check that training twice with the same seed gives the same weights.
+
+     Ten points (0, 0), (10, 10), ..., (90, 90) are parallelized and a
+     five-component model is trained twice with identical arguments,
+     including the seed. The component weights of the two models must
+     agree to seven decimal places.
+     """
+     from pyspark.mllib.clustering import GaussianMixture
+     x = range(0, 100, 10)
+     y = range(0, 100, 10)
+     data = self.sc.parallelize([[a, b] for a, b in zip(x, y)])
+     clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001,
+                                       maxIterations=100, seed=63)
+     clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001,
+                                       maxIterations=100, seed=63)
+     # zip() stops at the shorter sequence, so compare lengths first;
+     # otherwise a missing component weight would go unnoticed.
+     self.assertEqual(len(clusters1.weights), len(clusters2.weights))
+     for c1, c2 in zip(clusters1.weights, clusters2.weights):
+         # Compare the difference rather than two independently rounded
+         # values: rounding each side separately can disagree when nearly
+         # identical floats sit on opposite sides of a rounding boundary.
+         self.assertAlmostEqual(c1, c2, places=7)
+
def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees