From 4d92af310ad29ade039e4130f91f2a3d9180deef Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 17 Aug 2016 11:18:33 -0700 Subject: [SPARK-16446][SPARKR][ML] Gaussian Mixture Model wrapper in SparkR ## What changes were proposed in this pull request? Gaussian Mixture Model wrapper in SparkR, similarly to R's ```mvnormalmixEM```. ## How was this patch tested? Unit test. Author: Yanbo Liang Closes #14392 from yanboliang/spark-16446. --- R/pkg/NAMESPACE | 3 +- R/pkg/R/generics.R | 7 ++ R/pkg/R/mllib.R | 139 ++++++++++++++++++++++++++++++++- R/pkg/inst/tests/testthat/test_mllib.R | 62 +++++++++++++++ 4 files changed, 208 insertions(+), 3 deletions(-) (limited to 'R/pkg') diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 1e23b233c1..c71eec5ce0 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -25,7 +25,8 @@ exportMethods("glm", "fitted", "spark.naiveBayes", "spark.survreg", - "spark.isoreg") + "spark.isoreg", + "spark.gaussianMixture") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ebacc11741..06bb25d62d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1308,6 +1308,13 @@ setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spar #' @export setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) +#' @rdname spark.gaussianMixture +#' @export +setGeneric("spark.gaussianMixture", + function(data, formula, ...) { + standardGeneric("spark.gaussianMixture") + }) + #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 0dcc54d7af..db74046056 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -60,6 +60,13 @@ setClass("KMeansModel", representation(jobj = "jobj")) #' @note IsotonicRegressionModel since 2.1.0 setClass("IsotonicRegressionModel", representation(jobj = "jobj")) +#' S4 class that represents a GaussianMixtureModel +#' +#' @param jobj a Java object reference to the backing Scala GaussianMixtureModel +#' @export +#' @note GaussianMixtureModel since 2.1.0 +setClass("GaussianMixtureModel", representation(jobj = "jobj")) + #' Saves the MLlib model to the input path #' #' Saves the MLlib model to the input path. For more information, see the specific @@ -67,7 +74,7 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @rdname write.ml #' @name write.ml #' @export -#' @seealso \link{spark.glm}, \link{glm} +#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture} #' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg} #' @seealso \link{spark.isoreg} #' @seealso \link{read.ml} @@ -80,7 +87,7 @@ NULL #' @rdname predict #' @name predict #' @export -#' @seealso \link{spark.glm}, \link{glm} +#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture} #' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg} #' @seealso \link{spark.isoreg} NULL @@ -649,6 +656,25 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char invisible(callJMethod(writer, "save", path)) }) +# Save fitted MLlib model to the input path + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @aliases write.ml,GaussianMixtureModel,character-method +#' @rdname spark.gaussianMixture +#' @export +#' @note write.ml(GaussianMixtureModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"), + function(object, path, overwrite = FALSE) { + writer <- callJMethod(object@jobj, "write") + if (overwrite) { + writer <- callJMethod(writer, "overwrite") + } + invisible(callJMethod(writer, "save", path)) + }) + #' Load a fitted MLlib model from the input path. #' #' @param path Path of the model to read. @@ -676,6 +702,8 @@ read.ml <- function(path) { return(new("KMeansModel", jobj = jobj)) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) { return(new("IsotonicRegressionModel", jobj = jobj)) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) { + return(new("GaussianMixtureModel", jobj = jobj)) } else { stop(paste("Unsupported model: ", jobj)) } @@ -757,3 +785,110 @@ setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), function(object, newData) { return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) }) + +#' Multivariate Gaussian Mixture Model (GMM) +#' +#' Fits multivariate gaussian mixture model against a Spark DataFrame, similarly to R's +#' mvnormalmixEM(). Users can call \code{summary} to print a summary of the fitted model, +#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} +#' to save/load fitted models. +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' Note that the response variable of formula is empty in spark.gaussianMixture. +#' @param k number of independent Gaussians in the mixture model. +#' @param maxIter maximum iteration number. +#' @param tol the convergence tolerance. +#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method +#' @return \code{spark.gaussianMixture} returns a fitted multivariate gaussian mixture model. +#' @rdname spark.gaussianMixture +#' @name spark.gaussianMixture +#' @seealso mixtools: \url{https://cran.r-project.org/web/packages/mixtools/} +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' library(mvtnorm) +#' set.seed(100) +#' a <- rmvnorm(4, c(0, 0)) +#' b <- rmvnorm(6, c(3, 4)) +#' data <- rbind(a, b) +#' df <- createDataFrame(as.data.frame(data)) +#' model <- spark.gaussianMixture(df, ~ V1 + V2, k = 2) +#' summary(model) +#' +#' # fitted values on training data +#' fitted <- predict(model, df) +#' head(select(fitted, "V1", "prediction")) +#' +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(model, path) +#' +#' # can also read back the saved model and print +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.gaussianMixture since 2.1.0 +#' @seealso \link{predict}, \link{read.ml}, \link{write.ml} +setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, k = 2, maxIter = 100, tol = 0.01) { + formula <- paste(deparse(formula), collapse = "") + jobj <- callJStatic("org.apache.spark.ml.r.GaussianMixtureWrapper", "fit", data@sdf, + formula, as.integer(k), as.integer(maxIter), as.numeric(tol)) + return(new("GaussianMixtureModel", jobj = jobj)) + }) + +# Get the summary of a multivariate gaussian mixture model + +#' @param object a fitted gaussian mixture model. +#' @param ... currently not used argument(s) passed to the method. +#' @return \code{summary} returns the model's lambda, mu, sigma and posterior. +#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method +#' @rdname spark.gaussianMixture +#' @export +#' @note summary(GaussianMixtureModel) since 2.1.0 +setMethod("summary", signature(object = "GaussianMixtureModel"), + function(object, ...) { + jobj <- object@jobj + is.loaded <- callJMethod(jobj, "isLoaded") + lambda <- unlist(callJMethod(jobj, "lambda")) + muList <- callJMethod(jobj, "mu") + sigmaList <- callJMethod(jobj, "sigma") + k <- callJMethod(jobj, "k") + dim <- callJMethod(jobj, "dim") + mu <- c() + for (i in 1 : k) { + start <- (i - 1) * dim + 1 + end <- i * dim + mu[[i]] <- unlist(muList[start : end]) + } + sigma <- c() + for (i in 1 : k) { + start <- (i - 1) * dim * dim + 1 + end <- i * dim * dim + sigma[[i]] <- t(matrix(sigmaList[start : end], ncol = dim)) + } + posterior <- if (is.loaded) { + NULL + } else { + dataFrame(callJMethod(jobj, "posterior")) + } + return(list(lambda = lambda, mu = mu, sigma = sigma, + posterior = posterior, is.loaded = is.loaded)) + }) + +# Predicted values based on a gaussian mixture model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named +#' "prediction". +#' @aliases predict,GaussianMixtureModel,SparkDataFrame-method +#' @rdname spark.gaussianMixture +#' @export +#' @note predict(GaussianMixtureModel) since 2.1.0 +setMethod("predict", signature(object = "GaussianMixtureModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) + }) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index b759b28927..96179864a8 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -508,4 +508,66 @@ test_that("spark.isotonicRegression", { unlink(modelPath) }) +test_that("spark.gaussianMixture", { + # R code to reproduce the result. + # nolint start + #' library(mvtnorm) + #' set.seed(100) + #' a <- rmvnorm(4, c(0, 0)) + #' b <- rmvnorm(6, c(3, 4)) + #' data <- rbind(a, b) + #' model <- mvnormalmixEM(data, k = 2) + #' model$lambda + # + # [1] 0.4 0.6 + # + #' model$mu + # + # [1] -0.2614822 0.5128697 + # [1] 2.647284 4.544682 + # + #' model$sigma + # + # [[1]] + # [,1] [,2] + # [1,] 0.08427399 0.00548772 + # [2,] 0.00548772 0.09090715 + # + # [[2]] + # [,1] [,2] + # [1,] 0.1641373 -0.1673806 + # [2,] -0.1673806 0.7508951 + # nolint end + data <- list(list(-0.50219235, 0.1315312), list(-0.07891709, 0.8867848), + list(0.11697127, 0.3186301), list(-0.58179068, 0.7145327), + list(2.17474057, 3.6401379), list(3.08988614, 4.0962745), + list(2.79836605, 4.7398405), list(3.12337950, 3.9706833), + list(2.61114575, 4.5108563), list(2.08618581, 6.3102968)) + df <- createDataFrame(data, c("x1", "x2")) + model <- spark.gaussianMixture(df, ~ x1 + x2, k = 2) + stats <- summary(model) + rLambda <- c(0.4, 0.6) + rMu <- c(-0.2614822, 0.5128697, 2.647284, 4.544682) + rSigma <- c(0.08427399, 0.00548772, 0.00548772, 0.09090715, + 0.1641373, -0.1673806, -0.1673806, 0.7508951) + expect_equal(stats$lambda, rLambda) + expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3) + expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3) + p <- collect(select(predict(model, df), "prediction")) + expect_equal(p$prediction, c(0, 0, 0, 0, 1, 1, 1, 1, 1, 1)) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$lambda, stats2$lambda) + expect_equal(unlist(stats$mu), unlist(stats2$mu)) + expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) + + unlink(modelPath) +}) + sparkR.session.stop() -- cgit v1.2.3