diff options
author | Yin Huai <yhuai@databricks.com> | 2016-01-06 22:03:31 -0800 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2016-01-06 22:03:31 -0800 |
commit | e5cde7ab11a43334fa01b1bb8904da5c0774bc62 (patch) | |
tree | b124be7dff9c25aed3bbe013c7b9b5a456500021 /python/pyspark/mllib | |
parent | b6738520374637347ab5ae6c801730cdb6b35daa (diff) | |
download | spark-e5cde7ab11a43334fa01b1bb8904da5c0774bc62.tar.gz spark-e5cde7ab11a43334fa01b1bb8904da5c0774bc62.tar.bz2 spark-e5cde7ab11a43334fa01b1bb8904da5c0774bc62.zip |
Revert "[SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None"
This reverts commit fcd013cf70e7890aa25a8fe3cb6c8b36bf0e1f04.
Author: Yin Huai <yhuai@databricks.com>
Closes #10632 from yhuai/pythonStyle.
Diffstat (limited to 'python/pyspark/mllib')
-rw-r--r-- | python/pyspark/mllib/clustering.py | 2 | ||||
-rw-r--r-- | python/pyspark/mllib/tests.py | 12 |
2 files changed, 1 insertions, 13 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 48daa87e82..c9e6f1dec6 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -346,7 +346,7 @@ class GaussianMixture(object): if initialModel.k != k: raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s" % (initialModel.k, k)) - initialModelWeights = list(initialModel.weights) + initialModelWeights = initialModel.weights initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 97fed7662e..6ed03e3582 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -475,18 +475,6 @@ class ListTests(MLlibTestCase): for c1, c2 in zip(clusters1.weights, clusters2.weights): self.assertEqual(round(c1, 7), round(c2, 7)) - def test_gmm_with_initial_model(self): - from pyspark.mllib.clustering import GaussianMixture - data = self.sc.parallelize([ - (-10, -5), (-9, -4), (10, 5), (9, 4) - ]) - - gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63) - gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63, initialModel=gmm1) - self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) - def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ |