aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
author: FlytxtRnD <meethu.mathew@flytxt.com> 2015-05-15 10:43:18 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-05-15 10:43:18 -0700
commit: 8f4aaba0e4e3350ab152a476d08ff60e9495c6d2 (patch)
tree: 11f739d8f0f4dd9c9b7d5fe51f6d29fd8e2aef98 /python
parent: f96b85ab44b82736363764ea39ee62884007f4a3 (diff)
download: spark-8f4aaba0e4e3350ab152a476d08ff60e9495c6d2.tar.gz
spark-8f4aaba0e4e3350ab152a476d08ff60e9495c6d2.tar.bz2
spark-8f4aaba0e4e3350ab152a476d08ff60e9495c6d2.zip
[SPARK-7651] [MLLIB] [PYSPARK] GMM predict, predictSoft should raise error on bad input
In the Python API for Gaussian Mixture Model, the predict() and predictSoft() methods should raise an error when the input argument is not an RDD.

Author: FlytxtRnD <meethu.mathew@flytxt.com>

Closes #6180 from FlytxtRnD/GmmPredictException and squashes the following commits:

4b6aa11 [FlytxtRnD] Raise error if the input to predict()/predictSoft() is not an RDD
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/clustering.py 6
1 files changed, 6 insertions, 0 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index a53333dae6..b55583f822 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -212,6 +212,9 @@ class GaussianMixtureModel(object):
if isinstance(x, RDD):
cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
return cluster_labels
+ else:
+ raise TypeError("x should be represented by an RDD, "
+ "but got %s." % type(x))
def predictSoft(self, x):
"""
@@ -225,6 +228,9 @@ class GaussianMixtureModel(object):
membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
_convert_to_vector(self._weights), means, sigmas)
return membership_matrix.map(lambda x: pyarray.array('d', x))
+ else:
+ raise TypeError("x should be represented by an RDD, "
+ "but got %s." % type(x))
class GaussianMixture(object):