diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-08-17 11:18:33 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-08-17 11:18:33 -0700 |
commit | 4d92af310ad29ade039e4130f91f2a3d9180deef (patch) | |
tree | dc5467bebdb5ad7467387871e0c91f3e820603d5 /mllib/src/main/scala | |
parent | e3fec51fa1ed161789ab7aa32ed36efe357b5d31 (diff) | |
download | spark-4d92af310ad29ade039e4130f91f2a3d9180deef.tar.gz spark-4d92af310ad29ade039e4130f91f2a3d9180deef.tar.bz2 spark-4d92af310ad29ade039e4130f91f2a3d9180deef.zip |
[SPARK-16446][SPARKR][ML] Gaussian Mixture Model wrapper in SparkR
## What changes were proposed in this pull request?
Gaussian Mixture Model wrapper in SparkR, similar to R's ```mvnormalmixEM``` (from the mixtools package).
## How was this patch tested?
Unit test.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #14392 from yanboliang/spark-16446.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala | 128 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala | 2 |
2 files changed, 130 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala new file mode 100644 index 0000000000..1e8b3bbab6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.clustering.{GaussianMixture, GaussianMixtureModel} +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter} +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.functions._ + +private[r] class GaussianMixtureWrapper private ( + val pipeline: PipelineModel, + val dim: Int, + val isLoaded: Boolean = false) extends MLWritable { + + private val gmm: GaussianMixtureModel = pipeline.stages(1).asInstanceOf[GaussianMixtureModel] + + lazy val k: Int = gmm.getK + + lazy val lambda: Array[Double] = gmm.weights + + lazy val mu: Array[Double] = gmm.gaussians.flatMap(_.mean.toArray) + + lazy val sigma: Array[Double] = gmm.gaussians.flatMap(_.cov.toArray) + + lazy val vectorToArray = udf { probability: Vector => probability.toArray } + lazy val posterior: DataFrame = gmm.summary.probability + .withColumn("posterior", vectorToArray(col(gmm.summary.probabilityCol))) + .drop(gmm.summary.probabilityCol) + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(gmm.getFeaturesCol) + } + + override def write: MLWriter = new GaussianMixtureWrapper.GaussianMixtureWrapperWriter(this) + +} + +private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapper] { + + def fit( + data: DataFrame, + formula: String, + k: Int, + maxIter: Int, + tol: Double): GaussianMixtureWrapper = { + + val rFormulaModel = new RFormula() + .setFormula(formula) + .setFeaturesCol("features") + .fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = 
AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + val dim = features.length + + val gm = new GaussianMixture() + .setK(k) + .setMaxIter(maxIter) + .setTol(tol) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, gm)) + .fit(data) + + new GaussianMixtureWrapper(pipeline, dim) + } + + override def read: MLReader[GaussianMixtureWrapper] = new GaussianMixtureWrapperReader + + override def load(path: String): GaussianMixtureWrapper = super.load(path) + + class GaussianMixtureWrapperWriter(instance: GaussianMixtureWrapper) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("dim" -> instance.dim) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class GaussianMixtureWrapperReader extends MLReader[GaussianMixtureWrapper] { + + override def load(path: String): GaussianMixtureWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val dim = (rMetadata \ "dim").extract[Int] + new GaussianMixtureWrapper(pipeline, dim, isLoaded = true) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index f9a44d60e6..88ac26bc5e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -46,6 +46,8 @@ private[r] object RWrappers 
extends MLReader[Object] { KMeansWrapper.load(path) case "org.apache.spark.ml.r.IsotonicRegressionWrapper" => IsotonicRegressionWrapper.load(path) + case "org.apache.spark.ml.r.GaussianMixtureWrapper" => + GaussianMixtureWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") } |