aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala30
-rw-r--r--mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala131
-rw-r--r--mllib-local/src/test/scala/org/apache/spark/ml/impl/UtilsSuite.scala30
-rw-r--r--mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala83
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala108
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala4
-rw-r--r--python/pyspark/ml/clustering.py11
7 files changed, 353 insertions, 44 deletions
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala b/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
new file mode 100644
index 0000000000..112de982e4
--- /dev/null
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl
+
+
+private[ml] object Utils {
+
+ lazy val EPSILON = {
+ var eps = 1.0
+ while ((1.0 + (eps / 2.0)) != 1.0) {
+ eps /= 2.0
+ }
+ eps
+ }
+}
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
new file mode 100644
index 0000000000..c62a1eab20
--- /dev/null
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat.distribution
+
+import breeze.linalg.{diag, eigSym, max, DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
+
+import org.apache.spark.ml.impl.Utils
+import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors}
+
+
+/**
+ * This class provides basic functionality for a Multivariate Gaussian (Normal) Distribution. In
+ * the event that the covariance matrix is singular, the density will be computed in a
+ * reduced dimensional subspace under which the distribution is supported.
+ * (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]])
+ *
+ * @param mean The mean vector of the distribution
+ * @param cov The covariance matrix of the distribution
+ */
+class MultivariateGaussian(
+ val mean: Vector,
+ val cov: Matrix) extends Serializable {
+
+ require(cov.numCols == cov.numRows, "Covariance matrix must be square")
+ require(mean.size == cov.numCols, "Mean vector length must match covariance matrix size")
+
+ /** Private constructor taking Breeze types */
+ private[ml] def this(mean: BDV[Double], cov: BDM[Double]) = {
+ this(Vectors.fromBreeze(mean), Matrices.fromBreeze(cov))
+ }
+
+ private val breezeMu = mean.toBreeze.toDenseVector
+
+ /**
+ * Compute distribution dependent constants:
+ * rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t
+ * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
+ */
+ private val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants
+
+ /**
+ * Returns density of this multivariate Gaussian at given point, x
+ */
+ def pdf(x: Vector): Double = {
+ pdf(x.toBreeze)
+ }
+
+ /**
+ * Returns the log-density of this multivariate Gaussian at given point, x
+ */
+ def logpdf(x: Vector): Double = {
+ logpdf(x.toBreeze)
+ }
+
+ /** Returns density of this multivariate Gaussian at given point, x */
+ private[ml] def pdf(x: BV[Double]): Double = {
+ math.exp(logpdf(x))
+ }
+
+ /** Returns the log-density of this multivariate Gaussian at given point, x */
+ private[ml] def logpdf(x: BV[Double]): Double = {
+ val delta = x - breezeMu
+ val v = rootSigmaInv * delta
+ u + v.t * v * -0.5
+ }
+
+ /**
+ * Calculate distribution dependent components used for the density function:
+ * pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu))
+ * where k is length of the mean vector.
+ *
+ * We here compute distribution-fixed parts
+ * log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
+ * and
+ * D^(-1/2)^ * U, where sigma = U * D * U.t
+ *
+ * Both the determinant and the inverse can be computed from the singular value decomposition
+ * of sigma. Noting that covariance matrices are always symmetric and positive semi-definite,
+ * we can use the eigendecomposition. We also do not compute the inverse directly; noting
+ * that
+ *
+ * sigma = U * D * U.t
+ * inv(Sigma) = U * inv(D) * U.t
+ * = (D^{-1/2}^ * U.t).t * (D^{-1/2}^ * U.t)
+ *
+ * and thus
+ *
+ * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U.t * (x-mu))^2^
+ *
+ * To guard against singular covariance matrices, this method computes both the
+ * pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered
+ * to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and
+ * relation to the maximum singular value (same tolerance used by, e.g., Octave).
+ */
+ private def calculateCovarianceConstants: (BDM[Double], Double) = {
+ val eigSym.EigSym(d, u) = eigSym(cov.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t
+
+ // For numerical stability, values are considered to be non-zero only if they exceed tol.
+ // This prevents any inverted value from exceeding (eps * n * max(d))^-1
+ val tol = Utils.EPSILON * max(d) * d.length
+
+ try {
+ // log(pseudo-determinant) is sum of the logs of all non-zero singular values
+ val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum
+
+ // calculate the root-pseudo-inverse of the diagonal matrix of singular values
+ // by inverting the square root of all non-zero values
+ val pinvS = diag(new BDV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray))
+
+ (pinvS * u.t, -0.5 * (mean.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
+ } catch {
+ case uex: UnsupportedOperationException =>
+ throw new IllegalArgumentException("Covariance matrix has no non-zero singular values")
+ }
+ }
+}
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/impl/UtilsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/impl/UtilsSuite.scala
new file mode 100644
index 0000000000..44b122b694
--- /dev/null
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/impl/UtilsSuite.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl
+
+import org.apache.spark.ml.impl.Utils.EPSILON
+import org.apache.spark.ml.SparkMLFunSuite
+
+
+class UtilsSuite extends SparkMLFunSuite {
+
+ test("EPSILON") {
+ assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.")
+ assert(1.0 + EPSILON / 2.0 === 1.0, s"EPSILON is too big: $EPSILON.")
+ }
+}
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala
new file mode 100644
index 0000000000..f9306ed83e
--- /dev/null
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat.distribution
+
+import org.apache.spark.ml.SparkMLFunSuite
+import org.apache.spark.ml.linalg.{Matrices, Vectors}
+import org.apache.spark.ml.util.TestingUtils._
+
+
+class MultivariateGaussianSuite extends SparkMLFunSuite {
+
+ test("univariate") {
+ val x1 = Vectors.dense(0.0)
+ val x2 = Vectors.dense(1.5)
+
+ val mu = Vectors.dense(0.0)
+ val sigma1 = Matrices.dense(1, 1, Array(1.0))
+ val dist1 = new MultivariateGaussian(mu, sigma1)
+ assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5)
+ assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5)
+
+ val sigma2 = Matrices.dense(1, 1, Array(4.0))
+ val dist2 = new MultivariateGaussian(mu, sigma2)
+ assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5)
+ assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5)
+ }
+
+ test("multivariate") {
+ val x1 = Vectors.dense(0.0, 0.0)
+ val x2 = Vectors.dense(1.0, 1.0)
+
+ val mu = Vectors.dense(0.0, 0.0)
+ val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
+ val dist1 = new MultivariateGaussian(mu, sigma1)
+ assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5)
+ assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5)
+
+ val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
+ val dist2 = new MultivariateGaussian(mu, sigma2)
+ assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5)
+ assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5)
+ }
+
+ test("multivariate degenerate") {
+ val x1 = Vectors.dense(0.0, 0.0)
+ val x2 = Vectors.dense(1.0, 1.0)
+
+ val mu = Vectors.dense(0.0, 0.0)
+ val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
+ val dist = new MultivariateGaussian(mu, sigma)
+ assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5)
+ assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5)
+ }
+
+ test("SPARK-11302") {
+ val x = Vectors.dense(629, 640, 1.7188, 618.19)
+ val mu = Vectors.dense(
+ 1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697)
+ val sigma = Matrices.dense(4, 4, Array(
+ 166769.00466698944, 169336.6705268059, 12.820670788921873, 164243.93314092053,
+ 169336.6705268059, 172041.5670061245, 21.62590020524533, 166678.01075856484,
+ 12.820670788921873, 21.62590020524533, 0.872524191943962, 4.283255814732373,
+ 164243.93314092053, 166678.01075856484, 4.283255814732373, 161848.9196719207))
+ val dist = new MultivariateGaussian(mu, sigma)
+ // Agrees with R's dmvnorm: 7.154782e-05
+ assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index dfbc8b612c..ac86e4ce25 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -17,17 +17,21 @@
package org.apache.spark.ml.clustering
+import breeze.linalg.{DenseVector => BDV}
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.impl.Utils.EPSILON
+import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.stat.distribution.MultivariateGaussian
import org.apache.spark.ml.util._
-import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel}
-import org.apache.spark.mllib.linalg._
-import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
+import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM}
+import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
+ Vector => OldVector, Vectors => OldVectors, VectorUDT => OldVectorUDT}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -56,34 +60,42 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new OldVectorUDT)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
- SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(probabilityCol), new OldVectorUDT)
}
}
/**
* :: Experimental ::
- * Model fitted by GaussianMixture.
- * @param parentModel a model trained by spark.mllib.clustering.GaussianMixture.
+ *
+ * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
+ * are drawn from each Gaussian i with probability weights(i).
+ *
+ * @param weights Weight for each Gaussian distribution in the mixture.
+ * This is a multinomial probability distribution over the k Gaussians,
+ * where weights(i) is the weight for Gaussian i, and weights sum to 1.
+ * @param gaussians Array of [[MultivariateGaussian]] where gaussians(i) represents
+ * the Multivariate Gaussian (Normal) Distribution for Gaussian i
*/
@Since("2.0.0")
@Experimental
class GaussianMixtureModel private[ml] (
@Since("2.0.0") override val uid: String,
- private val parentModel: MLlibGMModel)
+ @Since("2.0.0") val weights: Array[Double],
+ @Since("2.0.0") val gaussians: Array[MultivariateGaussian])
extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable {
@Since("2.0.0")
override def copy(extra: ParamMap): GaussianMixtureModel = {
- val copied = new GaussianMixtureModel(uid, parentModel)
+ val copied = new GaussianMixtureModel(uid, weights, gaussians)
copyValues(copied, extra).setParent(this.parent)
}
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- val predUDF = udf((vector: Vector) => predict(vector))
- val probUDF = udf((vector: Vector) => predictProbability(vector))
+ val predUDF = udf((vector: OldVector) => predict(vector.asML))
+ val probUDF = udf((vector: OldVector) => OldVectors.fromML(predictProbability(vector.asML)))
dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
.withColumn($(probabilityCol), probUDF(col($(featuresCol))))
}
@@ -93,33 +105,32 @@ class GaussianMixtureModel private[ml] (
validateAndTransformSchema(schema)
}
- private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
+ private[clustering] def predict(features: Vector): Int = {
+ val r = predictProbability(features)
+ r.argmax
+ }
private[clustering] def predictProbability(features: Vector): Vector = {
- Vectors.dense(parentModel.predictSoft(features))
+ val probs: Array[Double] =
+ GaussianMixtureModel.computeProbabilities(features.toBreeze.toDenseVector, gaussians, weights)
+ Vectors.dense(probs)
}
- @Since("2.0.0")
- def weights: Array[Double] = parentModel.weights
-
- @Since("2.0.0")
- def gaussians: Array[MultivariateGaussian] = parentModel.gaussians
-
/**
* Retrieve Gaussian distributions as a DataFrame.
* Each row represents a Gaussian Distribution.
* Two columns are defined: mean and cov.
* Schema:
* {{{
- * root
- * |-- mean: vector (nullable = true)
- * |-- cov: matrix (nullable = true)
+ * root
+ * |-- mean: vector (nullable = true)
+ * |-- cov: matrix (nullable = true)
* }}}
*/
@Since("2.0.0")
def gaussiansDF: DataFrame = {
val modelGaussians = gaussians.map { gaussian =>
- (gaussian.mu, gaussian.sigma)
+ (OldVectors.fromML(gaussian.mean), OldMatrices.fromML(gaussian.cov))
}
val sc = SparkContext.getOrCreate()
val sqlContext = SQLContext.getOrCreate(sc)
@@ -166,7 +177,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
private[GaussianMixtureModel] class GaussianMixtureModelWriter(
instance: GaussianMixtureModel) extends MLWriter {
- private case class Data(weights: Array[Double], mus: Array[Vector], sigmas: Array[Matrix])
+ private case class Data(weights: Array[Double], mus: Array[OldVector], sigmas: Array[OldMatrix])
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
@@ -174,8 +185,8 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
// Save model data: weights and gaussians
val weights = instance.weights
val gaussians = instance.gaussians
- val mus = gaussians.map(_.mu)
- val sigmas = gaussians.map(_.sigma)
+ val mus = gaussians.map(g => OldVectors.fromML(g.mean))
+ val sigmas = gaussians.map(c => OldMatrices.fromML(c.cov))
val data = Data(weights, mus, sigmas)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
@@ -193,26 +204,50 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
val dataPath = new Path(path, "data").toString
val row = sqlContext.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
val weights = row.getSeq[Double](0).toArray
- val mus = row.getSeq[Vector](1).toArray
- val sigmas = row.getSeq[Matrix](2).toArray
+ val mus = row.getSeq[OldVector](1).toArray
+ val sigmas = row.getSeq[OldMatrix](2).toArray
require(mus.length == sigmas.length, "Length of Mu and Sigma array must match")
require(mus.length == weights.length, "Length of weight and Gaussian array must match")
- val gaussians = (mus zip sigmas).map {
+ val gaussians = mus.zip(sigmas).map {
case (mu, sigma) =>
- new MultivariateGaussian(mu, sigma)
+ new MultivariateGaussian(mu.asML, sigma.asML)
}
- val model = new GaussianMixtureModel(metadata.uid, new MLlibGMModel(weights, gaussians))
+ val model = new GaussianMixtureModel(metadata.uid, weights, gaussians)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
+
+ /**
+ * Compute the probability (partial assignment) for each cluster for the given data point.
+ * @param features Data point
+ * @param dists Gaussians for model
+ * @param weights Weights for each Gaussian
+ * @return Probability (partial assignment) for each of the k clusters
+ */
+ private[clustering]
+ def computeProbabilities(
+ features: BDV[Double],
+ dists: Array[MultivariateGaussian],
+ weights: Array[Double]): Array[Double] = {
+ val p = weights.zip(dists).map {
+ case (weight, dist) => EPSILON + weight * dist.pdf(features)
+ }
+ val pSum = p.sum
+ var i = 0
+ while (i < weights.length) {
+ p(i) /= pSum
+ i += 1
+ }
+ p
+ }
}
/**
* :: Experimental ::
- * GaussianMixture clustering.
+ * Gaussian Mixture clustering.
*/
@Since("2.0.0")
@Experimental
@@ -261,7 +296,7 @@ class GaussianMixture @Since("2.0.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
- val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
+ val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: OldVector) => point }
val algo = new MLlibGM()
.setK($(k))
@@ -269,8 +304,11 @@ class GaussianMixture @Since("2.0.0") (
.setSeed($(seed))
.setConvergenceTol($(tol))
val parentModel = algo.run(rdd)
- val model = copyValues(new GaussianMixtureModel(uid, parentModel)
- .setParent(this))
+ val gaussians = parentModel.gaussians.map { case g =>
+ new MultivariateGaussian(g.mu.asML, g.sigma.asML)
+ }
+ val model = copyValues(new GaussianMixtureModel(uid, parentModel.weights, gaussians))
+ .setParent(this)
val summary = new GaussianMixtureSummary(model.transform(dataset),
$(predictionCol), $(probabilityCol), $(featuresCol), $(k))
model.setSummary(summary)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index df6bb411d5..9d868174c1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -108,8 +108,8 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
test("read/write") {
def checkModelData(model: GaussianMixtureModel, model2: GaussianMixtureModel): Unit = {
assert(model.weights === model2.weights)
- assert(model.gaussians.map(_.mu) === model2.gaussians.map(_.mu))
- assert(model.gaussians.map(_.sigma) === model2.gaussians.map(_.sigma))
+ assert(model.gaussians.map(_.mean) === model2.gaussians.map(_.mean))
+ assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov))
}
val gm = new GaussianMixture()
testEstimatorAndModelReadWrite(gm, dataset,
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 9740ec45af..16ce02ee7d 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -39,8 +39,9 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
@since("2.0.0")
def weights(self):
"""
- Weights for each Gaussian distribution in the mixture, where weights[i] is
- the weight for Gaussian i, and weights.sum == 1.
+ Weight for each Gaussian distribution in the mixture.
+ This is a multinomial probability distribution over the k Gaussians,
+ where weights[i] is the weight for Gaussian i, and weights sum to 1.
"""
return self._call_java("weights")
@@ -50,11 +51,7 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Retrieve Gaussian distributions as a DataFrame.
Each row represents a Gaussian Distribution.
- Two columns are defined: mean and cov.
- Schema:
- root
- -- mean: vector (nullable = true)
- -- cov: matrix (nullable = true)
+ The DataFrame has two columns: mean (Vector) and cov (Matrix).
"""
return self._call_java("gaussiansDF")