aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala27
-rw-r--r--python/pyspark/ml/clustering.py146
2 files changed, 170 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index ead8ad7806..dfbc8b612c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering
import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
@@ -27,7 +28,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel}
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -104,6 +105,27 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
def gaussians: Array[MultivariateGaussian] = parentModel.gaussians
+ /**
+ * Retrieve Gaussian distributions as a DataFrame.
+ * Each row represents a Gaussian Distribution.
+ * Two columns are defined: mean and cov.
+ * Schema:
+ * {{{
+ * root
+ * |-- mean: vector (nullable = true)
+ * |-- cov: matrix (nullable = true)
+ * }}}
+ */
+ @Since("2.0.0")
+ def gaussiansDF: DataFrame = {
+ val modelGaussians = gaussians.map { gaussian =>
+ (gaussian.mu, gaussian.sigma)
+ }
+ val sc = SparkContext.getOrCreate()
+ val sqlContext = SQLContext.getOrCreate(sc)
+ sqlContext.createDataFrame(modelGaussians).toDF("mean", "cov")
+ }
+
@Since("2.0.0")
override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this)
@@ -247,7 +269,8 @@ class GaussianMixture @Since("2.0.0") (
.setSeed($(seed))
.setConvergenceTol($(tol))
val parentModel = algo.run(rdd)
- val model = copyValues(new GaussianMixtureModel(uid, parentModel).setParent(this))
+ val model = copyValues(new GaussianMixtureModel(uid, parentModel)
+ .setParent(this))
val summary = new GaussianMixtureSummary(model.transform(dataset),
$(predictionCol), $(probabilityCol), $(featuresCol), $(k))
model.setSummary(summary)
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index ecdaa3a71c..4ce8012754 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -22,7 +22,151 @@ from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc
__all__ = ['BisectingKMeans', 'BisectingKMeansModel',
- 'KMeans', 'KMeansModel']
+ 'KMeans', 'KMeansModel',
+ 'GaussianMixture', 'GaussianMixtureModel']
+
+
+class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
+ """
+ .. note:: Experimental
+
+ Model fitted by GaussianMixture.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def weights(self):
+ """
+ Weights for each Gaussian distribution in the mixture, where weights[i] is
+ the weight for Gaussian i, and weights.sum == 1.
+ """
+ return self._call_java("weights")
+
+ @property
+ @since("2.0.0")
+ def gaussiansDF(self):
+ """
+ Retrieve Gaussian distributions as a DataFrame.
+ Each row represents a Gaussian Distribution.
+ Two columns are defined: mean and cov.
+ Schema:
+ root
+ -- mean: vector (nullable = true)
+ -- cov: matrix (nullable = true)
+ """
+ return self._call_java("gaussiansDF")
+
+
+@inherit_doc
+class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
+ HasProbabilityCol, JavaMLWritable, JavaMLReadable):
+ """
+ .. note:: Experimental
+
+ GaussianMixture clustering.
+
+ >>> from pyspark.mllib.linalg import Vectors
+
+ >>> data = [(Vectors.dense([-0.1, -0.05 ]),),
+ ... (Vectors.dense([-0.01, -0.1]),),
+ ... (Vectors.dense([0.9, 0.8]),),
+ ... (Vectors.dense([0.75, 0.935]),),
+ ... (Vectors.dense([-0.83, -0.68]),),
+ ... (Vectors.dense([-0.91, -0.76]),)]
+ >>> df = sqlContext.createDataFrame(data, ["features"])
+ >>> gm = GaussianMixture(k=3, tol=0.0001,
+ ... maxIter=10, seed=10)
+ >>> model = gm.fit(df)
+ >>> weights = model.weights
+ >>> len(weights)
+ 3
+ >>> model.gaussiansDF.show()
+ +--------------------+--------------------+
+ | mean| cov|
+ +--------------------+--------------------+
+ |[-0.0550000000000...|0.002025000000000...|
+ |[0.82499999999999...|0.005625000000000...|
+ |[-0.87,-0.7200000...|0.001600000000000...|
+ +--------------------+--------------------+
+ ...
+ >>> transformed = model.transform(df).select("features", "prediction")
+ >>> rows = transformed.collect()
+ >>> rows[4].prediction == rows[5].prediction
+ True
+ >>> rows[2].prediction == rows[3].prediction
+ True
+ >>> gmm_path = temp_path + "/gmm"
+ >>> gm.save(gmm_path)
+ >>> gm2 = GaussianMixture.load(gmm_path)
+ >>> gm2.getK()
+ 3
+ >>> model_path = temp_path + "/gmm_model"
+ >>> model.save(model_path)
+ >>> model2 = GaussianMixtureModel.load(model_path)
+ >>> model2.weights == model.weights
+ True
+ >>> model2.gaussiansDF.show()
+ +--------------------+--------------------+
+ | mean| cov|
+ +--------------------+--------------------+
+ |[-0.0550000000000...|0.002025000000000...|
+ |[0.82499999999999...|0.005625000000000...|
+ |[-0.87,-0.7200000...|0.001600000000000...|
+ +--------------------+--------------------+
+ ...
+
+ .. versionadded:: 2.0.0
+ """
+
+ k = Param(Params._dummy(), "k", "number of clusters to create",
+ typeConverter=TypeConverters.toInt)
+
+ @keyword_only
+ def __init__(self, featuresCol="features", predictionCol="prediction", k=2,
+ probabilityCol="probability", tol=0.01, maxIter=100, seed=None):
+ """
+ __init__(self, featuresCol="features", predictionCol="prediction", k=2, \
+ probabilityCol="probability", tol=0.01, maxIter=100, seed=None)
+ """
+ super(GaussianMixture, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.GaussianMixture",
+ self.uid)
+ self._setDefault(k=2, tol=0.01, maxIter=100)
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ def _create_model(self, java_model):
+ return GaussianMixtureModel(java_model)
+
+ @keyword_only
+ @since("2.0.0")
+ def setParams(self, featuresCol="features", predictionCol="prediction", k=2,
+ probabilityCol="probability", tol=0.01, maxIter=100, seed=None):
+ """
+ setParams(self, featuresCol="features", predictionCol="prediction", k=2, \
+ probabilityCol="probability", tol=0.01, maxIter=100, seed=None)
+
+ Sets params for GaussianMixture.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ @since("2.0.0")
+ def setK(self, value):
+ """
+ Sets the value of :py:attr:`k`.
+ """
+ self._set(k=value)
+ return self
+
+ @since("2.0.0")
+ def getK(self):
+ """
+ Gets the value of `k`
+ """
+ return self.getOrDefault(self.k)
class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):