From 8f4aaba0e4e3350ab152a476d08ff60e9495c6d2 Mon Sep 17 00:00:00 2001 From: FlytxtRnD Date: Fri, 15 May 2015 10:43:18 -0700 Subject: [SPARK-7651] [MLLIB] [PYSPARK] GMM predict, predictSoft should raise error on bad input In the Python API for Gaussian Mixture Model, predict() and predictSoft() methods should raise an error when the input argument is not an RDD. Author: FlytxtRnD Closes #6180 from FlytxtRnD/GmmPredictException and squashes the following commits: 4b6aa11 [FlytxtRnD] Raise error if the input to predict()/predictSoft() is not an RDD --- python/pyspark/mllib/clustering.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'python') diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index a53333dae6..b55583f822 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -212,6 +212,9 @@ class GaussianMixtureModel(object): if isinstance(x, RDD): cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) return cluster_labels + else: + raise TypeError("x should be represented by an RDD, " + "but got %s." % type(x)) def predictSoft(self, x): """ @@ -225,6 +228,9 @@ class GaussianMixtureModel(object): membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), _convert_to_vector(self._weights), means, sigmas) return membership_matrix.map(lambda x: pyarray.array('d', x)) + else: + raise TypeError("x should be represented by an RDD, " + "but got %s." % type(x)) class GaussianMixture(object): -- cgit v1.2.3