diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-01-11 14:43:25 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-01-11 14:43:25 -0800 |
commit | ee4ee02b86be8756a6d895a2e23e80862134a6d3 (patch) | |
tree | 7d4402425ee3c9c10e38bd5bb85d9186cbde8fc9 /python/pyspark | |
parent | a767ee8a0599f5482717493a3298413c65d8ff89 (diff) | |
download | spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.tar.gz spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.tar.bz2 spark-ee4ee02b86be8756a6d895a2e23e80862134a6d3.zip |
[SPARK-12603][MLLIB] PySpark MLlib GaussianMixtureModel should support single instance predict/predictSoft
PySpark MLlib ```GaussianMixtureModel``` should support single instance ```predict/predictSoft``` just like Scala do.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #10552 from yanboliang/spark-12603.
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/mllib/clustering.py | 35 |
1 files changed, 22 insertions, 13 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index d22a7f4c3b..580cb512d8 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -202,16 +202,25 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, ... 0.9,0.8,0.75,0.935, - ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2)) + ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2), 2) >>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001, ... maxIterations=50, seed=10) >>> labels = model.predict(clusterdata_1).collect() >>> labels[0]==labels[1] False >>> labels[1]==labels[2] - True + False >>> labels[4]==labels[5] True + >>> model.predict([-0.1,-0.05]) + 0 + >>> softPredicted = model.predictSoft([-0.1,-0.05]) + >>> abs(softPredicted[0] - 1.0) < 0.001 + True + >>> abs(softPredicted[1] - 0.0) < 0.001 + True + >>> abs(softPredicted[2] - 0.0) < 0.001 + True >>> path = tempfile.mkdtemp() >>> model.save(sc, path) @@ -277,26 +286,27 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): @since('1.3.0') def predict(self, x): """ - Find the cluster to which the points in 'x' has maximum membership - in this model. + Find the cluster to which the point 'x' or each point in RDD 'x' + has maximum membership in this model. - :param x: RDD of data points. - :return: cluster_labels. RDD of cluster labels. + :param x: vector or RDD of vector represents data points. + :return: cluster label or RDD of cluster labels. """ if isinstance(x, RDD): cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) return cluster_labels else: - raise TypeError("x should be represented by an RDD, " - "but got %s." % type(x)) + z = self.predictSoft(x) + return z.argmax() @since('1.3.0') def predictSoft(self, x): """ - Find the membership of each point in 'x' to all mixture components. + Find the membership of point 'x' or each point in RDD 'x' to all mixture components. - :param x: RDD of data points. - :return: membership_matrix. RDD of array of double values. + :param x: vector or RDD of vector represents data points. + :return: the membership value to all mixture components for vector 'x' + or each vector in RDD 'x'. """ if isinstance(x, RDD): means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) @@ -304,8 +314,7 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): _convert_to_vector(self.weights), means, sigmas) return membership_matrix.map(lambda x: pyarray.array('d', x)) else: - raise TypeError("x should be represented by an RDD, " - "but got %s." % type(x)) + return self.call("predictSoft", _convert_to_vector(x)).toArray() @classmethod @since('1.5.0') |