aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- python/pyspark/mllib/clustering.py 6
1 files changed, 6 insertions, 0 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index a53333dae6..b55583f822 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -212,6 +212,9 @@ class GaussianMixtureModel(object):
if isinstance(x, RDD):
cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
return cluster_labels
+ else:
+ raise TypeError("x should be represented by an RDD, "
+ "but got %s." % type(x))
def predictSoft(self, x):
"""
@@ -225,6 +228,9 @@ class GaussianMixtureModel(object):
membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
_convert_to_vector(self._weights), means, sigmas)
return membership_matrix.map(lambda x: pyarray.array('d', x))
+ else:
+ raise TypeError("x should be represented by an RDD, "
+ "but got %s." % type(x))
class GaussianMixture(object):