aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorwm624@hotmail.com <wm624@hotmail.com>2016-04-25 10:48:15 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-25 10:48:15 -0700
commitb50e2eca93e6b0c391901f1e6d30628f8b6ebaa5 (patch)
tree615ec12aa6fb245f70953516335754afce73afbd /python
parenta680562a6f87a03a00f71bad1c424267ae75c641 (diff)
downloadspark-b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5.tar.gz
spark-b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5.tar.bz2
spark-b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5.zip
[SPARK-14433][PYSPARK][ML] PySpark ml GaussianMixture
## What changes were proposed in this pull request? Add Python API in ML for GaussianMixture ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) Add doctest and test cases are the same as mllib Python tests ./dev/lint-python PEP8 checks passed. rm -rf _build/* pydoc checks passed. ./python/run-tests --python-executables=python2.7 --modules=pyspark-ml Running PySpark tests. Output is in /Users/mwang/spark_ws_0904/python/unit-tests.log Will test against the following Python executables: ['python2.7'] Will test the following Python modules: ['pyspark-ml'] Finished test(python2.7): pyspark.ml.evaluation (18s) Finished test(python2.7): pyspark.ml.clustering (40s) Finished test(python2.7): pyspark.ml.classification (49s) Finished test(python2.7): pyspark.ml.recommendation (44s) Finished test(python2.7): pyspark.ml.feature (64s) Finished test(python2.7): pyspark.ml.regression (45s) Finished test(python2.7): pyspark.ml.tuning (30s) Finished test(python2.7): pyspark.ml.tests (56s) Tests passed in 106 seconds Author: wm624@hotmail.com <wm624@hotmail.com> Closes #12402 from wangmiao1981/gmm.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/clustering.py146
1 files changed, 145 insertions, 1 deletions
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index ecdaa3a71c..4ce8012754 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -22,7 +22,151 @@ from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc
__all__ = ['BisectingKMeans', 'BisectingKMeansModel',
- 'KMeans', 'KMeansModel']
+ 'KMeans', 'KMeansModel',
+ 'GaussianMixture', 'GaussianMixtureModel']
+
+
+class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
+ """
+ .. note:: Experimental
+
+ Model fitted by GaussianMixture.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def weights(self):
+ """
+ Weights for each Gaussian distribution in the mixture, where weights[i] is
+ the weight for Gaussian i, and weights.sum == 1.
+ """
+ return self._call_java("weights")
+
+ @property
+ @since("2.0.0")
+ def gaussiansDF(self):
+ """
+ Retrieve Gaussian distributions as a DataFrame.
+ Each row represents a Gaussian Distribution.
+ Two columns are defined: mean and cov.
+ Schema:
+ root
+ -- mean: vector (nullable = true)
+ -- cov: matrix (nullable = true)
+ """
+ return self._call_java("gaussiansDF")
+
+
+@inherit_doc
+class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
+ HasProbabilityCol, JavaMLWritable, JavaMLReadable):
+ """
+ .. note:: Experimental
+
+ GaussianMixture clustering.
+
+ >>> from pyspark.mllib.linalg import Vectors
+
+ >>> data = [(Vectors.dense([-0.1, -0.05 ]),),
+ ... (Vectors.dense([-0.01, -0.1]),),
+ ... (Vectors.dense([0.9, 0.8]),),
+ ... (Vectors.dense([0.75, 0.935]),),
+ ... (Vectors.dense([-0.83, -0.68]),),
+ ... (Vectors.dense([-0.91, -0.76]),)]
+ >>> df = sqlContext.createDataFrame(data, ["features"])
+ >>> gm = GaussianMixture(k=3, tol=0.0001,
+ ... maxIter=10, seed=10)
+ >>> model = gm.fit(df)
+ >>> weights = model.weights
+ >>> len(weights)
+ 3
+ >>> model.gaussiansDF.show()
+ +--------------------+--------------------+
+ | mean| cov|
+ +--------------------+--------------------+
+ |[-0.0550000000000...|0.002025000000000...|
+ |[0.82499999999999...|0.005625000000000...|
+ |[-0.87,-0.7200000...|0.001600000000000...|
+ +--------------------+--------------------+
+ ...
+ >>> transformed = model.transform(df).select("features", "prediction")
+ >>> rows = transformed.collect()
+ >>> rows[4].prediction == rows[5].prediction
+ True
+ >>> rows[2].prediction == rows[3].prediction
+ True
+ >>> gmm_path = temp_path + "/gmm"
+ >>> gm.save(gmm_path)
+ >>> gm2 = GaussianMixture.load(gmm_path)
+ >>> gm2.getK()
+ 3
+ >>> model_path = temp_path + "/gmm_model"
+ >>> model.save(model_path)
+ >>> model2 = GaussianMixtureModel.load(model_path)
+ >>> model2.weights == model.weights
+ True
+ >>> model2.gaussiansDF.show()
+ +--------------------+--------------------+
+ | mean| cov|
+ +--------------------+--------------------+
+ |[-0.0550000000000...|0.002025000000000...|
+ |[0.82499999999999...|0.005625000000000...|
+ |[-0.87,-0.7200000...|0.001600000000000...|
+ +--------------------+--------------------+
+ ...
+
+ .. versionadded:: 2.0.0
+ """
+
+ k = Param(Params._dummy(), "k", "number of clusters to create",
+ typeConverter=TypeConverters.toInt)
+
+ @keyword_only
+ def __init__(self, featuresCol="features", predictionCol="prediction", k=2,
+ probabilityCol="probability", tol=0.01, maxIter=100, seed=None):
+ """
+ __init__(self, featuresCol="features", predictionCol="prediction", k=2, \
+ probabilityCol="probability", tol=0.01, maxIter=100, seed=None)
+ """
+ super(GaussianMixture, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.GaussianMixture",
+ self.uid)
+ self._setDefault(k=2, tol=0.01, maxIter=100)
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ def _create_model(self, java_model):
+ return GaussianMixtureModel(java_model)
+
+ @keyword_only
+ @since("2.0.0")
+ def setParams(self, featuresCol="features", predictionCol="prediction", k=2,
+ probabilityCol="probability", tol=0.01, maxIter=100, seed=None):
+ """
+ setParams(self, featuresCol="features", predictionCol="prediction", k=2, \
+ probabilityCol="probability", tol=0.01, maxIter=100, seed=None)
+
+ Sets params for GaussianMixture.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ @since("2.0.0")
+ def setK(self, value):
+ """
+ Sets the value of :py:attr:`k`.
+ """
+ self._set(k=value)
+ return self
+
+ @since("2.0.0")
+ def getK(self):
+ """
+ Gets the value of `k`
+ """
+ return self.getOrDefault(self.k)
class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):