aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
author: Yanbo Liang <ybliang8@gmail.com> 2016-01-11 14:43:25 -0800
committer: Joseph K. Bradley <joseph@databricks.com> 2016-01-11 14:43:25 -0800
commit: ee4ee02b86be8756a6d895a2e23e80862134a6d3 (patch)
tree: 7d4402425ee3c9c10e38bd5bb85d9186cbde8fc9 /python/pyspark
parent: a767ee8a0599f5482717493a3298413c65d8ff89 (diff)
download: spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.tar.gz
spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.tar.bz2
spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.zip
[SPARK-12603][MLLIB] PySpark MLlib GaussianMixtureModel should support single instance predict/predictSoft
PySpark MLlib ```GaussianMixtureModel``` should support single-instance ```predict/predictSoft```, just like the Scala API does. Author: Yanbo Liang <ybliang8@gmail.com> Closes #10552 from yanboliang/spark-12603.
Diffstat (limited to 'python/pyspark')
-rw-r--r-- python/pyspark/mllib/clustering.py 35
1 files changed, 22 insertions, 13 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index d22a7f4c3b..580cb512d8 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -202,16 +202,25 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
>>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
... 0.9,0.8,0.75,0.935,
- ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
+ ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2), 2)
>>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001,
... maxIterations=50, seed=10)
>>> labels = model.predict(clusterdata_1).collect()
>>> labels[0]==labels[1]
False
>>> labels[1]==labels[2]
- True
+ False
>>> labels[4]==labels[5]
True
+ >>> model.predict([-0.1,-0.05])
+ 0
+ >>> softPredicted = model.predictSoft([-0.1,-0.05])
+ >>> abs(softPredicted[0] - 1.0) < 0.001
+ True
+ >>> abs(softPredicted[1] - 0.0) < 0.001
+ True
+ >>> abs(softPredicted[2] - 0.0) < 0.001
+ True
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
@@ -277,26 +286,27 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
@since('1.3.0')
def predict(self, x):
"""
- Find the cluster to which the points in 'x' has maximum membership
- in this model.
+ Find the cluster to which the point 'x' or each point in RDD 'x'
+ has maximum membership in this model.
- :param x: RDD of data points.
- :return: cluster_labels. RDD of cluster labels.
+ :param x: vector or RDD of vector represents data points.
+ :return: cluster label or RDD of cluster labels.
"""
if isinstance(x, RDD):
cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
return cluster_labels
else:
- raise TypeError("x should be represented by an RDD, "
- "but got %s." % type(x))
+ z = self.predictSoft(x)
+ return z.argmax()
@since('1.3.0')
def predictSoft(self, x):
"""
- Find the membership of each point in 'x' to all mixture components.
+ Find the membership of point 'x' or each point in RDD 'x' to all mixture components.
- :param x: RDD of data points.
- :return: membership_matrix. RDD of array of double values.
+ :param x: vector or RDD of vector represents data points.
+ :return: the membership value to all mixture components for vector 'x'
+ or each vector in RDD 'x'.
"""
if isinstance(x, RDD):
means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
@@ -304,8 +314,7 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
_convert_to_vector(self.weights), means, sigmas)
return membership_matrix.map(lambda x: pyarray.array('d', x))
else:
- raise TypeError("x should be represented by an RDD, "
- "but got %s." % type(x))
+ return self.call("predictSoft", _convert_to_vector(x)).toArray()
@classmethod
@since('1.5.0')