diff options
author | Zheng RuiFeng <ruifengz@foxmail.com> | 2016-04-06 11:45:16 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-06 11:45:16 -0700 |
commit | af73d9737874f7adaec3cd19ac889ab3badb8e2a (patch) | |
tree | 4114096fedd154920fc2de795fb27c47267cd11b /mllib/src | |
parent | 8cffcb60deb82d04a5c6e144ec9927f6f7addc8b (diff) | |
download | spark-af73d9737874f7adaec3cd19ac889ab3badb8e2a.tar.gz spark-af73d9737874f7adaec3cd19ac889ab3badb8e2a.tar.bz2 spark-af73d9737874f7adaec3cd19ac889ab3badb8e2a.zip |
[SPARK-13538][ML] Add GaussianMixture to ML
JIRA: https://issues.apache.org/jira/browse/SPARK-13538
## What changes were proposed in this pull request?
Add GaussianMixture and GaussianMixtureModel to the ML package.
## How was this patch tested?
Unit tests and manual tests were run.
Local Scalastyle checks passed.
Author: Zheng RuiFeng <ruifengz@foxmail.com>
Author: Ruifeng Zheng <ruifengz@foxmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #11419 from zhengruifeng/mlgmm.
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala | 311 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala | 133 |
2 files changed, 444 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala new file mode 100644 index 0000000000..120bf3cf9d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.clustering + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{IntParam, ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.stat.distribution.MultivariateGaussian +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{IntegerType, StructType} + + +/** + * Common params for GaussianMixture and GaussianMixtureModel + */ +private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter with HasFeaturesCol + with HasSeed with HasPredictionCol with HasProbabilityCol with HasTol { + + /** + * Set the number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + @Since("2.0.0") + final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) + + /** @group getParam */ + @Since("2.0.0") + def getK: Int = $(k) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT) + } +} + +/** + * :: Experimental :: + * Model fitted by GaussianMixture. + * @param parentModel a model trained by spark.mllib.clustering.GaussianMixture. 
+ */ +@Since("2.0.0") +@Experimental +class GaussianMixtureModel private[ml] ( + @Since("2.0.0") override val uid: String, + private val parentModel: MLlibGMModel) + extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable { + + @Since("2.0.0") + override def copy(extra: ParamMap): GaussianMixtureModel = { + val copied = new GaussianMixtureModel(uid, parentModel) + copyValues(copied, extra).setParent(this.parent) + } + + @Since("2.0.0") + override def transform(dataset: DataFrame): DataFrame = { + val predUDF = udf((vector: Vector) => predict(vector)) + val probUDF = udf((vector: Vector) => predictProbability(vector)) + dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) + .withColumn($(probabilityCol), probUDF(col($(featuresCol)))) + } + + @Since("2.0.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + + private[clustering] def predictProbability(features: Vector): Vector = { + Vectors.dense(parentModel.predictSoft(features)) + } + + @Since("2.0.0") + def weights: Array[Double] = parentModel.weights + + @Since("2.0.0") + def gaussians: Array[MultivariateGaussian] = parentModel.gaussians + + @Since("2.0.0") + override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this) + + private var trainingSummary: Option[GaussianMixtureSummary] = None + + private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * Return true if there exists summary of model. + */ + @Since("2.0.0") + def hasSummary: Boolean = trainingSummary.nonEmpty + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. 
+ */ + @Since("2.0.0") + def summary: GaussianMixtureSummary = trainingSummary.getOrElse { + throw new RuntimeException( + s"No training summary available for the ${this.getClass.getSimpleName}") + } +} + +@Since("2.0.0") +object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { + + @Since("2.0.0") + override def read: MLReader[GaussianMixtureModel] = new GaussianMixtureModelReader + + @Since("2.0.0") + override def load(path: String): GaussianMixtureModel = super.load(path) + + /** [[MLWriter]] instance for [[GaussianMixtureModel]] */ + private[GaussianMixtureModel] class GaussianMixtureModelWriter( + instance: GaussianMixtureModel) extends MLWriter { + + private case class Data(weights: Array[Double], mus: Array[Vector], sigmas: Array[Matrix]) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: weights and gaussians + val weights = instance.weights + val gaussians = instance.gaussians + val mus = gaussians.map(_.mu) + val sigmas = gaussians.map(_.sigma) + val data = Data(weights, mus, sigmas) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class GaussianMixtureModelReader extends MLReader[GaussianMixtureModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GaussianMixtureModel].getName + + override def load(path: String): GaussianMixtureModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val row = sqlContext.read.parquet(dataPath).select("weights", "mus", "sigmas").head() + val weights = row.getSeq[Double](0).toArray + val mus = row.getSeq[Vector](1).toArray + val sigmas = row.getSeq[Matrix](2).toArray + require(mus.length == sigmas.length, "Length of Mu and Sigma array must match") + require(mus.length == 
weights.length, "Length of weight and Gaussian array must match") + + val gaussians = (mus zip sigmas).map { + case (mu, sigma) => + new MultivariateGaussian(mu, sigma) + } + val model = new GaussianMixtureModel(metadata.uid, new MLlibGMModel(weights, gaussians)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} + +/** + * :: Experimental :: + * GaussianMixture clustering. + */ +@Since("2.0.0") +@Experimental +class GaussianMixture @Since("2.0.0") ( + @Since("2.0.0") override val uid: String) + extends Estimator[GaussianMixtureModel] with GaussianMixtureParams with DefaultParamsWritable { + + setDefault( + k -> 2, + maxIter -> 100, + tol -> 0.01) + + @Since("2.0.0") + override def copy(extra: ParamMap): GaussianMixture = defaultCopy(extra) + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("GaussianMixture")) + + /** @group setParam */ + @Since("2.0.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setProbabilityCol(value: String): this.type = set(probabilityCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setK(value: Int): this.type = set(k, value) + + /** @group setParam */ + @Since("2.0.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("2.0.0") + def setTol(value: Double): this.type = set(tol, value) + + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("2.0.0") + override def fit(dataset: DataFrame): GaussianMixtureModel = { + val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } + + val algo = new MLlibGM() + .setK($(k)) + .setMaxIterations($(maxIter)) + .setSeed($(seed)) + .setConvergenceTol($(tol)) + val parentModel = algo.run(rdd) + val model = copyValues(new 
GaussianMixtureModel(uid, parentModel).setParent(this)) + val summary = new GaussianMixtureSummary(model.transform(dataset), + $(predictionCol), $(probabilityCol), $(featuresCol), $(k)) + model.setSummary(summary) + } + + @Since("2.0.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + +@Since("2.0.0") +object GaussianMixture extends DefaultParamsReadable[GaussianMixture] { + + @Since("2.0.0") + override def load(path: String): GaussianMixture = super.load(path) +} + +/** + * :: Experimental :: + * Summary of GaussianMixture. + * + * @param predictions [[DataFrame]] produced by [[GaussianMixtureModel.transform()]] + * @param predictionCol Name for column of predicted clusters in `predictions` + * @param probabilityCol Name for column of predicted probability of each cluster in `predictions` + * @param featuresCol Name for column of features in `predictions` + * @param k Number of clusters + */ +@Since("2.0.0") +@Experimental +class GaussianMixtureSummary private[clustering] ( + @Since("2.0.0") @transient val predictions: DataFrame, + @Since("2.0.0") val predictionCol: String, + @Since("2.0.0") val probabilityCol: String, + @Since("2.0.0") val featuresCol: String, + @Since("2.0.0") val k: Int) extends Serializable { + + /** + * Cluster centers of the transformed data. + */ + @Since("2.0.0") + @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + + /** + * Probability of each cluster. + */ + @Since("2.0.0") + @transient lazy val probability: DataFrame = predictions.select(probabilityCol) + + /** + * Size of (number of data points in) each cluster. 
+ */ + @Since("2.0.0") + lazy val clusterSizes: Array[Long] = { + val sizes = Array.fill[Long](k)(0) + cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { + case Row(cluster: Int, count: Long) => sizes(cluster) = count + } + sizes + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala new file mode 100644 index 0000000000..8edd44e5f1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.DataFrame + + +class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + + final val k = 5 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + } + + test("default parameters") { + val gm = new GaussianMixture() + + assert(gm.getK === 2) + assert(gm.getFeaturesCol === "features") + assert(gm.getPredictionCol === "prediction") + assert(gm.getMaxIter === 100) + assert(gm.getTol === 0.01) + } + + test("set parameters") { + val gm = new GaussianMixture() + .setK(9) + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction") + .setProbabilityCol("test_probability") + .setMaxIter(33) + .setSeed(123) + .setTol(1e-3) + + assert(gm.getK === 9) + assert(gm.getFeaturesCol === "test_feature") + assert(gm.getPredictionCol === "test_prediction") + assert(gm.getProbabilityCol === "test_probability") + assert(gm.getMaxIter === 33) + assert(gm.getSeed === 123) + assert(gm.getTol === 1e-3) + } + + test("parameters validation") { + intercept[IllegalArgumentException] { + new GaussianMixture().setK(1) + } + } + + test("fit, transform, and summary") { + val predictionColName = "gm_prediction" + val probabilityColName = "gm_probability" + val gm = new GaussianMixture().setK(k).setMaxIter(2).setPredictionCol(predictionColName) + .setProbabilityCol(probabilityColName).setSeed(1) + val model = gm.fit(dataset) + assert(model.hasParent) + assert(model.weights.length === k) + assert(model.gaussians.length === k) + + val transformed = model.transform(dataset) + val expectedColumns = Array("features", predictionColName, probabilityColName) + expectedColumns.foreach { column => + 
assert(transformed.columns.contains(column)) + } + + // Check validity of model summary + val numRows = dataset.count() + assert(model.hasSummary) + val summary: GaussianMixtureSummary = model.summary + assert(summary.predictionCol === predictionColName) + assert(summary.probabilityCol === probabilityColName) + assert(summary.featuresCol === "features") + assert(summary.predictions.count() === numRows) + for (c <- Array(predictionColName, probabilityColName, "features")) { + assert(summary.predictions.columns.contains(c)) + } + assert(summary.cluster.columns === Array(predictionColName)) + assert(summary.probability.columns === Array(probabilityColName)) + val clusterSizes = summary.clusterSizes + assert(clusterSizes.length === k) + assert(clusterSizes.sum === numRows) + assert(clusterSizes.forall(_ >= 0)) + } + + test("read/write") { + def checkModelData(model: GaussianMixtureModel, model2: GaussianMixtureModel): Unit = { + assert(model.weights === model2.weights) + assert(model.gaussians.map(_.mu) === model2.gaussians.map(_.mu)) + assert(model.gaussians.map(_.sigma) === model2.gaussians.map(_.sigma)) + } + val gm = new GaussianMixture() + testEstimatorAndModelReadWrite(gm, dataset, + GaussianMixtureSuite.allParamSettings, checkModelData) + } +} + +object GaussianMixtureSuite { + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "probabilityCol" -> "myProbability", + "k" -> 3, + "maxIter" -> 2, + "tol" -> 0.01 + ) +} |