about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- python/pyspark/mllib/tests.py | 26
1 files changed, 26 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 61e0cf5d90..42aa228737 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -167,6 +167,32 @@ class ListTests(PySparkTestCase):
# TODO: Allow small numeric difference.
self.assertTrue(array_equal(c1, c2))
+    def test_gmm(self):
+        """Train a 2-component GaussianMixture and compare labels pairwise."""
+        from pyspark.mllib.clustering import GaussianMixture
+        # Two points with positive coordinates, then two with negative ones.
+        data = self.sc.parallelize([
+            [1, 2], [8, 9],
+            [-4, -3], [-6, -7],
+        ])
+        clusters = GaussianMixture.train(data, 2, convergenceTol=0.001,
+                                         maxIterations=100, seed=56)
+        labels = clusters.predict(data).collect()
+        self.assertEqual(labels[0], labels[1])
+        self.assertEqual(labels[2], labels[3])
+
+    def test_gmm_deterministic(self):
+        """Two trainings with the same seed must yield matching weights."""
+        from pyspark.mllib.clustering import GaussianMixture
+        data = self.sc.parallelize([[v, v] for v in range(0, 100, 10)])
+        # Identical data, parameters and seed are used for both runs.
+        clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001,
+                                          maxIterations=100, seed=63)
+        clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001,
+                                          maxIterations=100, seed=63)
+        for c1, c2 in zip(clusters1.weights, clusters2.weights):
+            self.assertAlmostEqual(c1, c2, places=7)
+
def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees