diff options
author | wm624@hotmail.com <wm624@hotmail.com> | 2016-04-25 10:48:15 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-25 10:48:15 -0700 |
commit | b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5 (patch) | |
tree | 615ec12aa6fb245f70953516335754afce73afbd /python/pyspark/ml/clustering.py | |
parent | a680562a6f87a03a00f71bad1c424267ae75c641 (diff) | |
download | spark-b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5.tar.gz spark-b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5.tar.bz2 spark-b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5.zip |
[SPARK-14433][PYSPARK][ML] PySpark ml GaussianMixture
## What changes were proposed in this pull request?
Add Python API in ML for GaussianMixture
## How was this patch tested?
(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
Add doctest and test cases are the same as mllib Python tests
./dev/lint-python
PEP8 checks passed.
rm -rf _build/*
pydoc checks passed.
./python/run-tests --python-executables=python2.7 --modules=pyspark-ml
Running PySpark tests. Output is in /Users/mwang/spark_ws_0904/python/unit-tests.log
Will test against the following Python executables: ['python2.7']
Will test the following Python modules: ['pyspark-ml']
Finished test(python2.7): pyspark.ml.evaluation (18s)
Finished test(python2.7): pyspark.ml.clustering (40s)
Finished test(python2.7): pyspark.ml.classification (49s)
Finished test(python2.7): pyspark.ml.recommendation (44s)
Finished test(python2.7): pyspark.ml.feature (64s)
Finished test(python2.7): pyspark.ml.regression (45s)
Finished test(python2.7): pyspark.ml.tuning (30s)
Finished test(python2.7): pyspark.ml.tests (56s)
Tests passed in 106 seconds
Author: wm624@hotmail.com <wm624@hotmail.com>
Closes #12402 from wangmiao1981/gmm.
Diffstat (limited to 'python/pyspark/ml/clustering.py')
-rw-r--r-- | python/pyspark/ml/clustering.py | 146 |
1 files changed, 145 insertions, 1 deletions
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index ecdaa3a71c..4ce8012754 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -22,7 +22,151 @@ from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc __all__ = ['BisectingKMeans', 'BisectingKMeansModel', - 'KMeans', 'KMeansModel'] + 'KMeans', 'KMeansModel', + 'GaussianMixture', 'GaussianMixtureModel'] + + +class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): + """ + .. note:: Experimental + + Model fitted by GaussianMixture. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def weights(self): + """ + Weights for each Gaussian distribution in the mixture, where weights[i] is + the weight for Gaussian i, and weights.sum == 1. + """ + return self._call_java("weights") + + @property + @since("2.0.0") + def gaussiansDF(self): + """ + Retrieve Gaussian distributions as a DataFrame. + Each row represents a Gaussian Distribution. + Two columns are defined: mean and cov. + Schema: + root + -- mean: vector (nullable = true) + -- cov: matrix (nullable = true) + """ + return self._call_java("gaussiansDF") + + +@inherit_doc +class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, + HasProbabilityCol, JavaMLWritable, JavaMLReadable): + """ + .. note:: Experimental + + GaussianMixture clustering. + + >>> from pyspark.mllib.linalg import Vectors + + >>> data = [(Vectors.dense([-0.1, -0.05 ]),), + ... (Vectors.dense([-0.01, -0.1]),), + ... (Vectors.dense([0.9, 0.8]),), + ... (Vectors.dense([0.75, 0.935]),), + ... (Vectors.dense([-0.83, -0.68]),), + ... (Vectors.dense([-0.91, -0.76]),)] + >>> df = sqlContext.createDataFrame(data, ["features"]) + >>> gm = GaussianMixture(k=3, tol=0.0001, + ... maxIter=10, seed=10) + >>> model = gm.fit(df) + >>> weights = model.weights + >>> len(weights) + 3 + >>> model.gaussiansDF.show() + +--------------------+--------------------+ + | mean| cov| + +--------------------+--------------------+ + |[-0.0550000000000...|0.002025000000000...| + |[0.82499999999999...|0.005625000000000...| + |[-0.87,-0.7200000...|0.001600000000000...| + +--------------------+--------------------+ + ... + >>> transformed = model.transform(df).select("features", "prediction") + >>> rows = transformed.collect() + >>> rows[4].prediction == rows[5].prediction + True + >>> rows[2].prediction == rows[3].prediction + True + >>> gmm_path = temp_path + "/gmm" + >>> gm.save(gmm_path) + >>> gm2 = GaussianMixture.load(gmm_path) + >>> gm2.getK() + 3 + >>> model_path = temp_path + "/gmm_model" + >>> model.save(model_path) + >>> model2 = GaussianMixtureModel.load(model_path) + >>> model2.weights == model.weights + True + >>> model2.gaussiansDF.show() + +--------------------+--------------------+ + | mean| cov| + +--------------------+--------------------+ + |[-0.0550000000000...|0.002025000000000...| + |[0.82499999999999...|0.005625000000000...| + |[-0.87,-0.7200000...|0.001600000000000...| + +--------------------+--------------------+ + ... + + .. versionadded:: 2.0.0 + """ + + k = Param(Params._dummy(), "k", "number of clusters to create", + typeConverter=TypeConverters.toInt) + + @keyword_only + def __init__(self, featuresCol="features", predictionCol="prediction", k=2, + probabilityCol="probability", tol=0.01, maxIter=100, seed=None): + """ + __init__(self, featuresCol="features", predictionCol="prediction", k=2, \ + probabilityCol="probability", tol=0.01, maxIter=100, seed=None) + """ + super(GaussianMixture, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.GaussianMixture", + self.uid) + self._setDefault(k=2, tol=0.01, maxIter=100) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + def _create_model(self, java_model): + return GaussianMixtureModel(java_model) + + @keyword_only + @since("2.0.0") + def setParams(self, featuresCol="features", predictionCol="prediction", k=2, + probabilityCol="probability", tol=0.01, maxIter=100, seed=None): + """ + setParams(self, featuresCol="features", predictionCol="prediction", k=2, \ + probabilityCol="probability", tol=0.01, maxIter=100, seed=None) + + Sets params for GaussianMixture. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("2.0.0") + def setK(self, value): + """ + Sets the value of :py:attr:`k`. + """ + self._set(k=value) + return self + + @since("2.0.0") + def getK(self): + """ + Gets the value of `k` + """ + return self.getOrDefault(self.k) class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): |