author | FlytxtRnD <meethu.mathew@flytxt.com> | 2015-05-15 10:43:18 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-05-15 10:43:18 -0700 |
commit | 8f4aaba0e4e3350ab152a476d08ff60e9495c6d2 | |
tree | 11f739d8f0f4dd9c9b7d5fe51f6d29fd8e2aef98 /python | |
parent | f96b85ab44b82736363764ea39ee62884007f4a3 | |
[SPARK-7651] [MLLIB] [PYSPARK] GMM predict, predictSoft should raise error on bad input
In the Python API for Gaussian Mixture Model, predict() and predictSoft() methods should raise an error when the input argument is not an RDD.
Author: FlytxtRnD <meethu.mathew@flytxt.com>
Closes #6180 from FlytxtRnD/GmmPredictException and squashes the following commits:
4b6aa11 [FlytxtRnD] Raise error if the input to predict()/predictSoft() is not an RDD
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/clustering.py | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index a53333dae6..b55583f822 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -212,6 +212,9 @@ class GaussianMixtureModel(object):
         if isinstance(x, RDD):
             cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
             return cluster_labels
+        else:
+            raise TypeError("x should be represented by an RDD, "
+                            "but got %s." % type(x))
 
     def predictSoft(self, x):
         """
@@ -225,6 +228,9 @@ class GaussianMixtureModel(object):
             membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
                                               _convert_to_vector(self._weights), means, sigmas)
             return membership_matrix.map(lambda x: pyarray.array('d', x))
+        else:
+            raise TypeError("x should be represented by an RDD, "
+                            "but got %s." % type(x))
 
 
 class GaussianMixture(object):
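
The effect of the patch can be illustrated without a Spark installation. The sketch below is not the PySpark implementation: `FakeRDD` and `SketchGMM` are hypothetical stand-ins invented for illustration, and `predictSoft` simply echoes precomputed membership rows instead of calling into MLlib. Only the `isinstance` check, the `else` branch, and the error message mirror the diff above.

```python
class FakeRDD(object):
    """Hypothetical stand-in for pyspark.RDD; supports only map() and collect()."""

    def __init__(self, rows):
        self._rows = list(rows)

    def map(self, func):
        return FakeRDD(func(row) for row in self._rows)

    def collect(self):
        return list(self._rows)


class SketchGMM(object):
    """Hypothetical model that reproduces only the input check added by the patch."""

    def predictSoft(self, x):
        if isinstance(x, FakeRDD):
            # Assumption: each row already holds membership probabilities,
            # so no real mixture-model computation happens here.
            return x.map(lambda row: list(row))
        else:
            raise TypeError("x should be represented by an RDD, "
                            "but got %s." % type(x))

    def predict(self, x):
        if isinstance(x, FakeRDD):
            # Same argmax over memberships as in the patched predict().
            return self.predictSoft(x).map(lambda z: z.index(max(z)))
        else:
            raise TypeError("x should be represented by an RDD, "
                            "but got %s." % type(x))


model = SketchGMM()
labels = model.predict(FakeRDD([[0.2, 0.8], [0.9, 0.1]])).collect()

try:
    model.predict([0.2, 0.8])  # a plain list, not an RDD
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

Before this change, the absence of an `else` branch meant that a non-RDD argument made both methods fall through and return `None` silently; with the branch in place the caller gets a `TypeError` naming the offending type.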