path: root/python/pyspark/ml
author: Joseph K. Bradley <joseph@databricks.com> 2016-04-26 16:53:16 -0700
committer: DB Tsai <dbt@netflix.com> 2016-04-26 16:53:16 -0700
commit: bd2c9a6d48ef6d489c747d9db2642bdef6b1f728
tree: 9a8a4864825aca4e8f11d4442d33e1ca4f7ac0c4 /python/pyspark/ml
parent: 0c99c23b7d9f0c3538cd2b062d551411712a2bcc
[SPARK-14732][ML] spark.ml GaussianMixture should use MultivariateGaussian in mllib-local
## What changes were proposed in this pull request?

Before, spark.ml GaussianMixtureModel used the spark.mllib MultivariateGaussian in its public API. This was added after 1.6, so we can modify this API without breaking APIs.

This PR copies MultivariateGaussian to mllib-local in spark.ml, with a few changes:
* Renamed fields to match numpy, scipy: mu => mean, sigma => cov

This PR then uses the spark.ml MultivariateGaussian in the spark.ml GaussianMixtureModel, which involves:
* Modifying the constructor
* Adding a computeProbabilities method

Also:
* Added EPSILON to mllib-local for use in MultivariateGaussian

## How was this patch tested?

Existing unit tests

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12593 from jkbradley/sparkml-gmm-fix.
Diffstat (limited to 'python/pyspark/ml')
-rw-r--r-- python/pyspark/ml/clustering.py | 11
1 file changed, 4 insertions, 7 deletions
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 9740ec45af..16ce02ee7d 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -39,8 +39,9 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
     @since("2.0.0")
     def weights(self):
         """
-        Weights for each Gaussian distribution in the mixture, where weights[i] is
-        the weight for Gaussian i, and weights.sum == 1.
+        Weight for each Gaussian distribution in the mixture.
+        This is a multinomial probability distribution over the k Gaussians,
+        where weights[i] is the weight for Gaussian i, and weights sum to 1.
         """
         return self._call_java("weights")
@@ -50,11 +51,7 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
         """
         Retrieve Gaussian distributions as a DataFrame.
         Each row represents a Gaussian Distribution.
-        Two columns are defined: mean and cov.
-        Schema:
-        root
-        -- mean: vector (nullable = true)
-        -- cov: matrix (nullable = true)
+        The DataFrame has two columns: mean (Vector) and cov (Matrix).
         """
         return self._call_java("gaussiansDF")
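
The two docstrings above describe the weights as a probability distribution over the k Gaussians and each Gaussian as a (mean, cov) pair. The following is a minimal pure-Python sketch of how those pieces combine, assuming one-dimensional data so that cov reduces to a variance. It does not use Spark, and the function names `gaussian_pdf` and `membership_probabilities` are hypothetical, chosen for illustration; whether the computeProbabilities method named in the commit message works this way is not confirmed by this page.

```python
import math


def gaussian_pdf(x, mean, cov):
    """Density of a 1-D Gaussian at x; in one dimension `cov` is the variance."""
    return math.exp(-0.5 * (x - mean) ** 2 / cov) / math.sqrt(2.0 * math.pi * cov)


def membership_probabilities(x, weights, gaussians):
    """Probability that x belongs to each component of the mixture.

    Each component contributes weights[i] * pdf_i(x); dividing by the total
    turns the contributions into probabilities that sum to 1.
    """
    joint = [
        w * gaussian_pdf(x, g["mean"], g["cov"])
        for w, g in zip(weights, gaussians)
    ]
    total = sum(joint)
    return [j / total for j in joint]


# Hypothetical mixture: the weights sum to 1, and each dict stands in for
# one row with a mean and a cov column (scalars here, not Vector/Matrix).
weights = [0.3, 0.7]
gaussians = [
    {"mean": 0.0, "cov": 1.0},
    {"mean": 5.0, "cov": 2.0},
]

probs = membership_probabilities(0.5, weights, gaussians)
print(probs)
```

The point x = 0.5 lies much closer to the first component, so that component receives nearly all of the probability even though its weight is smaller. For points far from every component, `total` can underflow to zero; a more careful version would work in log space.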