author     Yanbo Liang <ybliang8@gmail.com>      2016-08-17 11:18:33 -0700
committer  Xiangrui Meng <meng@databricks.com>   2016-08-17 11:18:33 -0700
commit     4d92af310ad29ade039e4130f91f2a3d9180deef (patch)
tree       dc5467bebdb5ad7467387871e0c91f3e820603d5 /R
parent     e3fec51fa1ed161789ab7aa32ed36efe357b5d31 (diff)
[SPARK-16446][SPARKR][ML] Gaussian Mixture Model wrapper in SparkR
## What changes were proposed in this pull request?

Gaussian Mixture Model wrapper in SparkR, similar to R's ```mvnormalmixEM```.

## How was this patch tested?

Unit tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #14392 from yanboliang/spark-16446.
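As context for the diff below, here is a minimal usage sketch of the API this patch adds, condensed from the roxygen `@examples` in `R/pkg/R/mllib.R` (the `mvtnorm` package and the `V1`/`V2` column names come from that example; a running SparkR session is assumed):

```r
# Sketch of the new SparkR Gaussian Mixture Model API added by this patch.
library(SparkR)
sparkR.session()

# Toy two-cluster data, as in the @examples block of mllib.R below.
library(mvtnorm)
set.seed(100)
a <- rmvnorm(4, c(0, 0))
b <- rmvnorm(6, c(3, 4))
df <- createDataFrame(as.data.frame(rbind(a, b)))   # columns V1, V2

# Fit a 2-component Gaussian mixture; the formula has no response variable.
model <- spark.gaussianMixture(df, ~ V1 + V2, k = 2)
summary(model)                                        # lambda, mu, sigma, posterior

# Predicted component labels on the training data.
head(select(predict(model, df), "V1", "prediction"))

# Persist and reload the fitted model.
path <- tempfile(pattern = "spark-gaussianMixture")
write.ml(model, path)
savedModel <- read.ml(path)
```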
Diffstat (limited to 'R')
-rw-r--r--  R/pkg/NAMESPACE                           3
-rw-r--r--  R/pkg/R/generics.R                        7
-rw-r--r--  R/pkg/R/mllib.R                         139
-rw-r--r--  R/pkg/inst/tests/testthat/test_mllib.R   62
4 files changed, 208 insertions, 3 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 1e23b233c1..c71eec5ce0 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -25,7 +25,8 @@ exportMethods("glm",
"fitted",
"spark.naiveBayes",
"spark.survreg",
- "spark.isoreg")
+ "spark.isoreg",
+ "spark.gaussianMixture")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index ebacc11741..06bb25d62d 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1308,6 +1308,13 @@ setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spar
#' @export
setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
+#' @rdname spark.gaussianMixture
+#' @export
+setGeneric("spark.gaussianMixture",
+ function(data, formula, ...) {
+ standardGeneric("spark.gaussianMixture")
+ })
+
#' @rdname write.ml
#' @export
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 0dcc54d7af..db74046056 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -60,6 +60,13 @@ setClass("KMeansModel", representation(jobj = "jobj"))
#' @note IsotonicRegressionModel since 2.1.0
setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
+#' S4 class that represents a GaussianMixtureModel
+#'
+#' @param jobj a Java object reference to the backing Scala GaussianMixtureModel
+#' @export
+#' @note GaussianMixtureModel since 2.1.0
+setClass("GaussianMixtureModel", representation(jobj = "jobj"))
+
#' Saves the MLlib model to the input path
#'
#' Saves the MLlib model to the input path. For more information, see the specific
@@ -67,7 +74,7 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
#' @rdname write.ml
#' @name write.ml
#' @export
-#' @seealso \link{spark.glm}, \link{glm}
+#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{spark.isoreg}
#' @seealso \link{read.ml}
@@ -80,7 +87,7 @@ NULL
#' @rdname predict
#' @name predict
#' @export
-#' @seealso \link{spark.glm}, \link{glm}
+#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{spark.isoreg}
NULL
@@ -649,6 +656,25 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
invisible(callJMethod(writer, "save", path))
})
+# Save fitted MLlib model to the input path
+
+#' @param path the directory where the model is saved.
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
+#' which means throw exception if the output path exists.
+#'
+#' @aliases write.ml,GaussianMixtureModel,character-method
+#' @rdname spark.gaussianMixture
+#' @export
+#' @note write.ml(GaussianMixtureModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+ })
+
#' Load a fitted MLlib model from the input path.
#'
#' @param path Path of the model to read.
@@ -676,6 +702,8 @@ read.ml <- function(path) {
return(new("KMeansModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
return(new("IsotonicRegressionModel", jobj = jobj))
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
+ return(new("GaussianMixtureModel", jobj = jobj))
} else {
stop(paste("Unsupported model: ", jobj))
}
@@ -757,3 +785,110 @@ setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})
+
+#' Multivariate Gaussian Mixture Model (GMM)
+#'
+#' Fits multivariate gaussian mixture model against a Spark DataFrame, similarly to R's
+#' mvnormalmixEM(). Users can call \code{summary} to print a summary of the fitted model,
+#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml}
+#' to save/load fitted models.
+#'
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~', '.', ':', '+', and '-'.
+#' Note that the response variable of formula is empty in spark.gaussianMixture.
+#' @param k number of independent Gaussians in the mixture model.
+#' @param maxIter maximum iteration number.
+#' @param tol the convergence tolerance.
+#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method
+#' @return \code{spark.gaussianMixture} returns a fitted multivariate gaussian mixture model.
+#' @rdname spark.gaussianMixture
+#' @name spark.gaussianMixture
+#' @seealso mixtools: \url{https://cran.r-project.org/web/packages/mixtools/}
+#' @export
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' library(mvtnorm)
+#' set.seed(100)
+#' a <- rmvnorm(4, c(0, 0))
+#' b <- rmvnorm(6, c(3, 4))
+#' data <- rbind(a, b)
+#' df <- createDataFrame(as.data.frame(data))
+#' model <- spark.gaussianMixture(df, ~ V1 + V2, k = 2)
+#' summary(model)
+#'
+#' # fitted values on training data
+#' fitted <- predict(model, df)
+#' head(select(fitted, "V1", "prediction"))
+#'
+#' # save fitted model to input path
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#'
+#' # can also read back the saved model and print
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.gaussianMixture since 2.1.0
+#' @seealso \link{predict}, \link{read.ml}, \link{write.ml}
+setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = "formula"),
+ function(data, formula, k = 2, maxIter = 100, tol = 0.01) {
+ formula <- paste(deparse(formula), collapse = "")
+ jobj <- callJStatic("org.apache.spark.ml.r.GaussianMixtureWrapper", "fit", data@sdf,
+ formula, as.integer(k), as.integer(maxIter), as.numeric(tol))
+ return(new("GaussianMixtureModel", jobj = jobj))
+ })
+
+# Get the summary of a multivariate gaussian mixture model
+
+#' @param object a fitted gaussian mixture model.
+#' @param ... currently not used argument(s) passed to the method.
+#' @return \code{summary} returns the model's lambda, mu, sigma and posterior.
+#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method
+#' @rdname spark.gaussianMixture
+#' @export
+#' @note summary(GaussianMixtureModel) since 2.1.0
+setMethod("summary", signature(object = "GaussianMixtureModel"),
+ function(object, ...) {
+ jobj <- object@jobj
+ is.loaded <- callJMethod(jobj, "isLoaded")
+ lambda <- unlist(callJMethod(jobj, "lambda"))
+ muList <- callJMethod(jobj, "mu")
+ sigmaList <- callJMethod(jobj, "sigma")
+ k <- callJMethod(jobj, "k")
+ dim <- callJMethod(jobj, "dim")
+ mu <- c()
+ for (i in 1 : k) {
+ start <- (i - 1) * dim + 1
+ end <- i * dim
+ mu[[i]] <- unlist(muList[start : end])
+ }
+ sigma <- c()
+ for (i in 1 : k) {
+ start <- (i - 1) * dim * dim + 1
+ end <- i * dim * dim
+ sigma[[i]] <- t(matrix(sigmaList[start : end], ncol = dim))
+ }
+ posterior <- if (is.loaded) {
+ NULL
+ } else {
+ dataFrame(callJMethod(jobj, "posterior"))
+ }
+ return(list(lambda = lambda, mu = mu, sigma = sigma,
+ posterior = posterior, is.loaded = is.loaded))
+ })
+
+# Predicted values based on a gaussian mixture model
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#' "prediction".
+#' @aliases predict,GaussianMixtureModel,SparkDataFrame-method
+#' @rdname spark.gaussianMixture
+#' @export
+#' @note predict(GaussianMixtureModel) since 2.1.0
+setMethod("predict", signature(object = "GaussianMixtureModel"),
+ function(object, newData) {
+ return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ })
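To make the shape of the `summary()` return value concrete, here is a hedged sketch of how a caller might inspect it; the field names follow the `setMethod("summary", ...)` implementation in the hunk above, and `model` is assumed to be a fitted `spark.gaussianMixture` model:

```r
# Sketch: inspecting the list returned by summary() on a GaussianMixtureModel.
stats <- summary(model)

stats$lambda        # numeric vector of k mixing weights
stats$mu[[1]]       # mean vector of the first component (length = feature dimension)
stats$sigma[[1]]    # covariance matrix of the first component (dim x dim)

# posterior holds membership probabilities on the training data as a SparkDataFrame,
# or NULL if the model was loaded from disk (stats$is.loaded is TRUE in that case).
if (!stats$is.loaded) {
  head(stats$posterior)
}
```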
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index b759b28927..96179864a8 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -508,4 +508,66 @@ test_that("spark.isotonicRegression", {
unlink(modelPath)
})
+test_that("spark.gaussianMixture", {
+ # R code to reproduce the result.
+ # nolint start
+ #' library(mvtnorm)
+ #' set.seed(100)
+ #' a <- rmvnorm(4, c(0, 0))
+ #' b <- rmvnorm(6, c(3, 4))
+ #' data <- rbind(a, b)
+ #' model <- mvnormalmixEM(data, k = 2)
+ #' model$lambda
+ #
+ # [1] 0.4 0.6
+ #
+ #' model$mu
+ #
+ # [1] -0.2614822 0.5128697
+ # [1] 2.647284 4.544682
+ #
+ #' model$sigma
+ #
+ # [[1]]
+ # [,1] [,2]
+ # [1,] 0.08427399 0.00548772
+ # [2,] 0.00548772 0.09090715
+ #
+ # [[2]]
+ # [,1] [,2]
+ # [1,] 0.1641373 -0.1673806
+ # [2,] -0.1673806 0.7508951
+ # nolint end
+ data <- list(list(-0.50219235, 0.1315312), list(-0.07891709, 0.8867848),
+ list(0.11697127, 0.3186301), list(-0.58179068, 0.7145327),
+ list(2.17474057, 3.6401379), list(3.08988614, 4.0962745),
+ list(2.79836605, 4.7398405), list(3.12337950, 3.9706833),
+ list(2.61114575, 4.5108563), list(2.08618581, 6.3102968))
+ df <- createDataFrame(data, c("x1", "x2"))
+ model <- spark.gaussianMixture(df, ~ x1 + x2, k = 2)
+ stats <- summary(model)
+ rLambda <- c(0.4, 0.6)
+ rMu <- c(-0.2614822, 0.5128697, 2.647284, 4.544682)
+ rSigma <- c(0.08427399, 0.00548772, 0.00548772, 0.09090715,
+ 0.1641373, -0.1673806, -0.1673806, 0.7508951)
+ expect_equal(stats$lambda, rLambda)
+ expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3)
+ expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3)
+ p <- collect(select(predict(model, df), "prediction"))
+ expect_equal(p$prediction, c(0, 0, 0, 0, 1, 1, 1, 1, 1, 1))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$lambda, stats2$lambda)
+ expect_equal(unlist(stats$mu), unlist(stats2$mu))
+ expect_equal(unlist(stats$sigma), unlist(stats2$sigma))
+
+ unlink(modelPath)
+})
+
sparkR.session.stop()
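For reference, the commented-out R code in the test above can be run as a standalone script to reproduce the expected values; this is a sketch assuming the `mvtnorm` and `mixtools` packages are installed (the exact numbers depend on the RNG seed shown in the test comments):

```r
# Standalone reproduction of the reference values quoted in test_mllib.R above.
library(mvtnorm)
library(mixtools)

set.seed(100)
a <- rmvnorm(4, c(0, 0))
b <- rmvnorm(6, c(3, 4))
data <- rbind(a, b)

model <- mvnormalmixEM(data, k = 2)
model$lambda   # mixing weights; the test expects c(0.4, 0.6)
model$mu       # component means, compared against rMu in the test
model$sigma    # component covariance matrices, compared against rSigma
```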