aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzero323 <matthew.szymkiewicz@gmail.com>2016-01-07 10:32:56 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-07 10:32:56 -0800
commit592f64985d0d58b4f6a0366bf975e04ca496bdbe (patch)
tree1826c62f3af4d9f7ff6e513af3c0b98ad258fa71
parent8113dbda0bd51fdbe20dbfad466b8d25304a01f4 (diff)
downloadspark-592f64985d0d58b4f6a0366bf975e04ca496bdbe.tar.gz
spark-592f64985d0d58b4f6a0366bf975e04ca496bdbe.tar.bz2
spark-592f64985d0d58b4f6a0366bf975e04ca496bdbe.zip
[SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None
If initial model passed to GMM is not empty it causes net.razorvine.pickle.PickleException. It can be fixed by converting initialModel.weights to list. Author: zero323 <matthew.szymkiewicz@gmail.com> Closes #10644 from zero323/SPARK-12006.
-rw-r--r--python/pyspark/mllib/clustering.py2
-rw-r--r--python/pyspark/mllib/tests.py12
2 files changed, 13 insertions, 1 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index c9e6f1dec6..48daa87e82 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -346,7 +346,7 @@ class GaussianMixture(object):
if initialModel.k != k:
raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s"
% (initialModel.k, k))
- initialModelWeights = initialModel.weights
+ initialModelWeights = list(initialModel.weights)
initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 6ed03e3582..3436a28b29 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -475,6 +475,18 @@ class ListTests(MLlibTestCase):
for c1, c2 in zip(clusters1.weights, clusters2.weights):
self.assertEqual(round(c1, 7), round(c2, 7))
+ def test_gmm_with_initial_model(self):
+ from pyspark.mllib.clustering import GaussianMixture
+ data = self.sc.parallelize([
+ (-10, -5), (-9, -4), (10, 5), (9, 4)
+ ])
+
+ gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001,
+ maxIterations=10, seed=63)
+ gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001,
+ maxIterations=10, seed=63, initialModel=gmm1)
+ self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0)
+
def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\