diff options
Diffstat (limited to 'python/pyspark/ml/clustering.py')
-rw-r--r-- | python/pyspark/ml/clustering.py | 146 |
1 files changed, 145 insertions, 1 deletions
# Post-patch contents of the hunk recorded as:
#   diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
#   index ecdaa3a71c..4ce8012754 100644
#   @@ -22,7 +22,151 @@
from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc

__all__ = [
    'BisectingKMeans',
    'BisectingKMeansModel',
    'KMeans',
    'KMeansModel',
    'GaussianMixture',
    'GaussianMixtureModel',
]


class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
    """
    .. note:: Experimental

    The model that results from fitting a GaussianMixture estimator.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def weights(self):
        """
        Mixture weights, one entry per Gaussian distribution: weights[i] is
        the weight of Gaussian i, and all entries sum to 1.
        """
        mixture_weights = self._call_java("weights")
        return mixture_weights

    @property
    @since("2.0.0")
    def gaussiansDF(self):
        """
        The Gaussian distributions of the mixture, returned as a DataFrame
        with one row per Gaussian distribution.

        The DataFrame has two columns:

        * mean: vector (nullable = true)
        * cov: matrix (nullable = true)
        """
        gaussians = self._call_java("gaussiansDF")
        return gaussians


@inherit_doc
class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
                      HasProbabilityCol, JavaMLWritable, JavaMLReadable):
    """
    .. note:: Experimental

    Clustering estimator based on a mixture of Gaussians; fitting it yields
    a GaussianMixtureModel.

    >>> from pyspark.mllib.linalg import Vectors

    >>> data = [(Vectors.dense([-0.1, -0.05 ]),),
    ...         (Vectors.dense([-0.01, -0.1]),),
    ...         (Vectors.dense([0.9, 0.8]),),
    ...         (Vectors.dense([0.75, 0.935]),),
    ...         (Vectors.dense([-0.83, -0.68]),),
    ...         (Vectors.dense([-0.91, -0.76]),)]
    >>> df = sqlContext.createDataFrame(data, ["features"])
    >>> gm = GaussianMixture(k=3, tol=0.0001,
    ...                      maxIter=10, seed=10)
    >>> model = gm.fit(df)
    >>> weights = model.weights
    >>> len(weights)
    3
    >>> model.gaussiansDF.show()
    +--------------------+--------------------+
    |                mean|                 cov|
    +--------------------+--------------------+
    |[-0.0550000000000...|0.002025000000000...|
    |[0.82499999999999...|0.005625000000000...|
    |[-0.87,-0.7200000...|0.001600000000000...|
    +--------------------+--------------------+
    ...
    >>> transformed = model.transform(df).select("features", "prediction")
    >>> rows = transformed.collect()
    >>> rows[4].prediction == rows[5].prediction
    True
    >>> rows[2].prediction == rows[3].prediction
    True
    >>> gmm_path = temp_path + "/gmm"
    >>> gm.save(gmm_path)
    >>> gm2 = GaussianMixture.load(gmm_path)
    >>> gm2.getK()
    3
    >>> model_path = temp_path + "/gmm_model"
    >>> model.save(model_path)
    >>> model2 = GaussianMixtureModel.load(model_path)
    >>> model2.weights == model.weights
    True
    >>> model2.gaussiansDF.show()
    +--------------------+--------------------+
    |                mean|                 cov|
    +--------------------+--------------------+
    |[-0.0550000000000...|0.002025000000000...|
    |[0.82499999999999...|0.005625000000000...|
    |[-0.87,-0.7200000...|0.001600000000000...|
    +--------------------+--------------------+
    ...

    .. versionadded:: 2.0.0
    """

    k = Param(
        Params._dummy(),
        "k",
        "number of clusters to create",
        typeConverter=TypeConverters.toInt,
    )

    @keyword_only
    def __init__(self, featuresCol="features", predictionCol="prediction", k=2,
                 probabilityCol="probability", tol=0.01, maxIter=100, seed=None):
        """
        __init__(self, featuresCol="features", predictionCol="prediction", k=2, \
                 probabilityCol="probability", tol=0.01, maxIter=100, seed=None)
        """
        super(GaussianMixture, self).__init__()
        java_class_name = "org.apache.spark.ml.clustering.GaussianMixture"
        self._java_obj = self._new_java_obj(java_class_name, self.uid)
        self._setDefault(k=2, tol=0.01, maxIter=100)
        # Forward exactly the keyword arguments captured for this call.
        self.setParams(**self.__init__._input_kwargs)

    def _create_model(self, java_model):
        # Wrap the given Java model object in its Python-side model class.
        fitted_model = GaussianMixtureModel(java_model)
        return fitted_model

    @keyword_only
    @since("2.0.0")
    def setParams(self, featuresCol="features", predictionCol="prediction", k=2,
                  probabilityCol="probability", tol=0.01, maxIter=100, seed=None):
        """
        setParams(self, featuresCol="features", predictionCol="prediction", k=2, \
                  probabilityCol="probability", tol=0.01, maxIter=100, seed=None)

        Sets params for GaussianMixture.
        """
        return self._set(**self.setParams._input_kwargs)

    @since("2.0.0")
    def setK(self, value):
        """
        Sets the value of :py:attr:`k` and returns this instance.
        """
        self._set(k=value)
        return self

    @since("2.0.0")
    def getK(self):
        """
        Gets the value of :py:attr:`k`, or its default when it has not been set.
        """
        k_param = self.k
        return self.getOrDefault(k_param)


# The diff's trailing context line opens the next definition, whose body lies
# outside this hunk:
#   class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):