aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
author: FlytxtRnD <meethu.mathew@flytxt.com> 2015-05-15 10:43:18 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-05-15 10:43:26 -0700
commit: dfdae5800c6a4f9b8e941138f61b784b24b0b00b (patch)
tree: 21c32d6b4e326e773124795d7d587209076a6556 /python
parent: d1f5651004449dad5fc4bf5d4ba3b2888f6b900a (diff)
download: spark-dfdae5800c6a4f9b8e941138f61b784b24b0b00b.tar.gz
spark-dfdae5800c6a4f9b8e941138f61b784b24b0b00b.tar.bz2
spark-dfdae5800c6a4f9b8e941138f61b784b24b0b00b.zip
[SPARK-7651] [MLLIB] [PYSPARK] GMM predict, predictSoft should raise error on bad input
In the Python API for Gaussian Mixture Model, the predict() and predictSoft() methods should raise an error when the input argument is not an RDD.

Author: FlytxtRnD <meethu.mathew@flytxt.com>

Closes #6180 from FlytxtRnD/GmmPredictException and squashes the following commits:

4b6aa11 [FlytxtRnD] Raise error if the input to predict()/predictSoft() is not an RDD

(cherry picked from commit 8f4aaba0e4e3350ab152a476d08ff60e9495c6d2)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/clustering.py | 6
1 files changed, 6 insertions, 0 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index a53333dae6..b55583f822 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -212,6 +212,9 @@ class GaussianMixtureModel(object):
if isinstance(x, RDD):
cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
return cluster_labels
+ else:
+ raise TypeError("x should be represented by an RDD, "
+ "but got %s." % type(x))
def predictSoft(self, x):
"""
@@ -225,6 +228,9 @@ class GaussianMixtureModel(object):
membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
_convert_to_vector(self._weights), means, sigmas)
return membership_matrix.map(lambda x: pyarray.array('d', x))
+ else:
+ raise TypeError("x should be represented by an RDD, "
+ "but got %s." % type(x))
class GaussianMixture(object):