aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Joseph K. Bradley <joseph@databricks.com> 2016-04-26 16:53:16 -0700
committer DB Tsai <dbt@netflix.com> 2016-04-26 16:53:16 -0700
commit bd2c9a6d48ef6d489c747d9db2642bdef6b1f728 (patch)
tree 9a8a4864825aca4e8f11d4442d33e1ca4f7ac0c4 /mllib
parent 0c99c23b7d9f0c3538cd2b062d551411712a2bcc (diff)
download spark-bd2c9a6d48ef6d489c747d9db2642bdef6b1f728.tar.gz
spark-bd2c9a6d48ef6d489c747d9db2642bdef6b1f728.tar.bz2
spark-bd2c9a6d48ef6d489c747d9db2642bdef6b1f728.zip
[SPARK-14732][ML] spark.ml GaussianMixture should use MultivariateGaussian in mllib-local
## What changes were proposed in this pull request?

Before, spark.ml GaussianMixtureModel used the spark.mllib MultivariateGaussian in its public API. This was added after 1.6, so we can modify this API without breaking APIs.

This PR copies MultivariateGaussian to mllib-local in spark.ml, with a few changes:
* Renamed fields to match numpy, scipy: mu => mean, sigma => cov

This PR then uses the spark.ml MultivariateGaussian in the spark.ml GaussianMixtureModel, which involves:
* Modifying the constructor
* Adding a computeProbabilities method

Also:
* Added EPSILON to mllib-local for use in MultivariateGaussian

## How was this patch tested?

Existing unit tests

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12593 from jkbradley/sparkml-gmm-fix.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala 108
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala 4
2 files changed, 75 insertions, 37 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index dfbc8b612c..ac86e4ce25 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -17,17 +17,21 @@
package org.apache.spark.ml.clustering
+import breeze.linalg.{DenseVector => BDV}
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.impl.Utils.EPSILON
+import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.stat.distribution.MultivariateGaussian
import org.apache.spark.ml.util._
-import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel}
-import org.apache.spark.mllib.linalg._
-import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
+import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM}
+import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
+ Vector => OldVector, Vectors => OldVectors, VectorUDT => OldVectorUDT}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -56,34 +60,42 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new OldVectorUDT)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
- SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(probabilityCol), new OldVectorUDT)
}
}
/**
* :: Experimental ::
- * Model fitted by GaussianMixture.
- * @param parentModel a model trained by spark.mllib.clustering.GaussianMixture.
+ *
+ * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
+ * are drawn from each Gaussian i with probability weights(i).
+ *
+ * @param weights Weight for each Gaussian distribution in the mixture.
+ * This is a multinomial probability distribution over the k Gaussians,
+ * where weights(i) is the weight for Gaussian i, and weights sum to 1.
+ * @param gaussians Array of [[MultivariateGaussian]] where gaussians(i) represents
+ * the Multivariate Gaussian (Normal) Distribution for Gaussian i
*/
@Since("2.0.0")
@Experimental
class GaussianMixtureModel private[ml] (
@Since("2.0.0") override val uid: String,
- private val parentModel: MLlibGMModel)
+ @Since("2.0.0") val weights: Array[Double],
+ @Since("2.0.0") val gaussians: Array[MultivariateGaussian])
extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable {
@Since("2.0.0")
override def copy(extra: ParamMap): GaussianMixtureModel = {
- val copied = new GaussianMixtureModel(uid, parentModel)
+ val copied = new GaussianMixtureModel(uid, weights, gaussians)
copyValues(copied, extra).setParent(this.parent)
}
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- val predUDF = udf((vector: Vector) => predict(vector))
- val probUDF = udf((vector: Vector) => predictProbability(vector))
+ val predUDF = udf((vector: OldVector) => predict(vector.asML))
+ val probUDF = udf((vector: OldVector) => OldVectors.fromML(predictProbability(vector.asML)))
dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
.withColumn($(probabilityCol), probUDF(col($(featuresCol))))
}
@@ -93,33 +105,32 @@ class GaussianMixtureModel private[ml] (
validateAndTransformSchema(schema)
}
- private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
+ private[clustering] def predict(features: Vector): Int = {
+ val r = predictProbability(features)
+ r.argmax
+ }
private[clustering] def predictProbability(features: Vector): Vector = {
- Vectors.dense(parentModel.predictSoft(features))
+ val probs: Array[Double] =
+ GaussianMixtureModel.computeProbabilities(features.toBreeze.toDenseVector, gaussians, weights)
+ Vectors.dense(probs)
}
- @Since("2.0.0")
- def weights: Array[Double] = parentModel.weights
-
- @Since("2.0.0")
- def gaussians: Array[MultivariateGaussian] = parentModel.gaussians
-
/**
* Retrieve Gaussian distributions as a DataFrame.
* Each row represents a Gaussian Distribution.
* Two columns are defined: mean and cov.
* Schema:
* {{{
- * root
- * |-- mean: vector (nullable = true)
- * |-- cov: matrix (nullable = true)
+ * root
+ * |-- mean: vector (nullable = true)
+ * |-- cov: matrix (nullable = true)
* }}}
*/
@Since("2.0.0")
def gaussiansDF: DataFrame = {
val modelGaussians = gaussians.map { gaussian =>
- (gaussian.mu, gaussian.sigma)
+ (OldVectors.fromML(gaussian.mean), OldMatrices.fromML(gaussian.cov))
}
val sc = SparkContext.getOrCreate()
val sqlContext = SQLContext.getOrCreate(sc)
@@ -166,7 +177,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
private[GaussianMixtureModel] class GaussianMixtureModelWriter(
instance: GaussianMixtureModel) extends MLWriter {
- private case class Data(weights: Array[Double], mus: Array[Vector], sigmas: Array[Matrix])
+ private case class Data(weights: Array[Double], mus: Array[OldVector], sigmas: Array[OldMatrix])
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
@@ -174,8 +185,8 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
// Save model data: weights and gaussians
val weights = instance.weights
val gaussians = instance.gaussians
- val mus = gaussians.map(_.mu)
- val sigmas = gaussians.map(_.sigma)
+ val mus = gaussians.map(g => OldVectors.fromML(g.mean))
+ val sigmas = gaussians.map(c => OldMatrices.fromML(c.cov))
val data = Data(weights, mus, sigmas)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
@@ -193,26 +204,50 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
val dataPath = new Path(path, "data").toString
val row = sqlContext.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
val weights = row.getSeq[Double](0).toArray
- val mus = row.getSeq[Vector](1).toArray
- val sigmas = row.getSeq[Matrix](2).toArray
+ val mus = row.getSeq[OldVector](1).toArray
+ val sigmas = row.getSeq[OldMatrix](2).toArray
require(mus.length == sigmas.length, "Length of Mu and Sigma array must match")
require(mus.length == weights.length, "Length of weight and Gaussian array must match")
- val gaussians = (mus zip sigmas).map {
+ val gaussians = mus.zip(sigmas).map {
case (mu, sigma) =>
- new MultivariateGaussian(mu, sigma)
+ new MultivariateGaussian(mu.asML, sigma.asML)
}
- val model = new GaussianMixtureModel(metadata.uid, new MLlibGMModel(weights, gaussians))
+ val model = new GaussianMixtureModel(metadata.uid, weights, gaussians)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
+
+ /**
+ * Compute the probability (partial assignment) for each cluster for the given data point.
+ * @param features Data point
+ * @param dists Gaussians for model
+ * @param weights Weights for each Gaussian
+ * @return Probability (partial assignment) for each of the k clusters
+ */
+ private[clustering]
+ def computeProbabilities(
+ features: BDV[Double],
+ dists: Array[MultivariateGaussian],
+ weights: Array[Double]): Array[Double] = {
+ val p = weights.zip(dists).map {
+ case (weight, dist) => EPSILON + weight * dist.pdf(features)
+ }
+ val pSum = p.sum
+ var i = 0
+ while (i < weights.length) {
+ p(i) /= pSum
+ i += 1
+ }
+ p
+ }
}
/**
* :: Experimental ::
- * GaussianMixture clustering.
+ * Gaussian Mixture clustering.
*/
@Since("2.0.0")
@Experimental
@@ -261,7 +296,7 @@ class GaussianMixture @Since("2.0.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
- val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
+ val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: OldVector) => point }
val algo = new MLlibGM()
.setK($(k))
@@ -269,8 +304,11 @@ class GaussianMixture @Since("2.0.0") (
.setSeed($(seed))
.setConvergenceTol($(tol))
val parentModel = algo.run(rdd)
- val model = copyValues(new GaussianMixtureModel(uid, parentModel)
- .setParent(this))
+ val gaussians = parentModel.gaussians.map { case g =>
+ new MultivariateGaussian(g.mu.asML, g.sigma.asML)
+ }
+ val model = copyValues(new GaussianMixtureModel(uid, parentModel.weights, gaussians))
+ .setParent(this)
val summary = new GaussianMixtureSummary(model.transform(dataset),
$(predictionCol), $(probabilityCol), $(featuresCol), $(k))
model.setSummary(summary)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index df6bb411d5..9d868174c1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -108,8 +108,8 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
test("read/write") {
def checkModelData(model: GaussianMixtureModel, model2: GaussianMixtureModel): Unit = {
assert(model.weights === model2.weights)
- assert(model.gaussians.map(_.mu) === model2.gaussians.map(_.mu))
- assert(model.gaussians.map(_.sigma) === model2.gaussians.map(_.sigma))
+ assert(model.gaussians.map(_.mean) === model2.gaussians.map(_.mean))
+ assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov))
}
val gm = new GaussianMixture()
testEstimatorAndModelReadWrite(gm, dataset,