aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-04-29 09:42:54 -0700
committerXiangrui Meng <meng@databricks.com>2016-04-29 09:43:04 -0700
commit87ac84d43729c54be100bb9ad7dc6e8fa14b8805 (patch)
treed3fbb8c5996a10177fd3af3579d160b6278509ac /R
parenta7d0fedc940721d09350f2e57ae85591e0a3d90e (diff)
downloadspark-87ac84d43729c54be100bb9ad7dc6e8fa14b8805.tar.gz
spark-87ac84d43729c54be100bb9ad7dc6e8fa14b8805.tar.bz2
spark-87ac84d43729c54be100bb9ad7dc6e8fa14b8805.zip
[SPARK-14314][SPARK-14315][ML][SPARKR] Model persistence in SparkR (glm & kmeans)
SparkR ```glm``` and ```kmeans``` model persistence. Unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Author: Gayathri Murali <gayathri.m.softie@gmail.com> Closes #12778 from yanboliang/spark-14311. Closes #12680 Closes #12683
Diffstat (limited to 'R')
-rw-r--r--R/pkg/R/mllib.R98
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R41
2 files changed, 127 insertions, 12 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 480301192d..c2326ea116 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -99,9 +99,9 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat
setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
function(object, ...) {
jobj <- object@jobj
+ is.loaded <- callJMethod(jobj, "isLoaded")
features <- callJMethod(jobj, "rFeatures")
coefficients <- callJMethod(jobj, "rCoefficients")
- deviance.resid <- callJMethod(jobj, "rDevianceResiduals")
dispersion <- callJMethod(jobj, "rDispersion")
null.deviance <- callJMethod(jobj, "rNullDeviance")
deviance <- callJMethod(jobj, "rDeviance")
@@ -110,15 +110,18 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
aic <- callJMethod(jobj, "rAic")
iter <- callJMethod(jobj, "rNumIterations")
family <- callJMethod(jobj, "rFamily")
-
- deviance.resid <- dataFrame(deviance.resid)
+ deviance.resid <- if (is.loaded) {
+ NULL
+ } else {
+ dataFrame(callJMethod(jobj, "rDevianceResiduals"))
+ }
coefficients <- matrix(coefficients, ncol = 4)
colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
rownames(coefficients) <- unlist(features)
ans <- list(deviance.resid = deviance.resid, coefficients = coefficients,
dispersion = dispersion, null.deviance = null.deviance,
deviance = deviance, df.null = df.null, df.residual = df.residual,
- aic = aic, iter = iter, family = family)
+ aic = aic, iter = iter, family = family, is.loaded = is.loaded)
class(ans) <- "summary.GeneralizedLinearRegressionModel"
return(ans)
})
@@ -129,12 +132,16 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
#' @name print.summary.GeneralizedLinearRegressionModel
#' @export
print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
- x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals",
+ if (x$is.loaded) {
+ cat("\nSaved-loaded model does not support output 'Deviance Residuals'.\n")
+ } else {
+ x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals",
c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)), c("Min", "1Q", "Median", "3Q", "Max"))
- x$deviance.resid <- zapsmall(x$deviance.resid, 5L)
- cat("\nDeviance Residuals: \n")
- cat("(Note: These are approximate quantiles with relative error <= 0.01)\n")
- print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L)
+ x$deviance.resid <- zapsmall(x$deviance.resid, 5L)
+ cat("\nDeviance Residuals: \n")
+ cat("(Note: These are approximate quantiles with relative error <= 0.01)\n")
+ print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L)
+ }
cat("\nCoefficients:\n")
print.default(x$coefficients, digits = 5L, na.print = "", print.gap = 2L)
@@ -246,6 +253,7 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
#' Get fitted result from a k-means model
#'
#' Get fitted result from a k-means model, similarly to R's fitted().
+#' Note: A saved-loaded model does not support this method.
#'
#' @param object A fitted k-means model
#' @return SparkDataFrame containing fitted values
@@ -260,7 +268,13 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
setMethod("fitted", signature(object = "KMeansModel"),
function(object, method = c("centers", "classes"), ...) {
method <- match.arg(method)
- return(dataFrame(callJMethod(object@jobj, "fitted", method)))
+ jobj <- object@jobj
+ is.loaded <- callJMethod(jobj, "isLoaded")
+ if (is.loaded) {
+ stop(paste("Saved-loaded k-means model does not support 'fitted' method"))
+ } else {
+ return(dataFrame(callJMethod(jobj, "fitted", method)))
+ }
})
#' Get the summary of a k-means model
@@ -280,15 +294,21 @@ setMethod("fitted", signature(object = "KMeansModel"),
setMethod("summary", signature(object = "KMeansModel"),
function(object, ...) {
jobj <- object@jobj
+ is.loaded <- callJMethod(jobj, "isLoaded")
features <- callJMethod(jobj, "features")
coefficients <- callJMethod(jobj, "coefficients")
- cluster <- callJMethod(jobj, "cluster")
k <- callJMethod(jobj, "k")
size <- callJMethod(jobj, "size")
coefficients <- t(matrix(coefficients, ncol = k))
colnames(coefficients) <- unlist(features)
rownames(coefficients) <- 1:k
- return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+ cluster <- if (is.loaded) {
+ NULL
+ } else {
+ dataFrame(callJMethod(jobj, "cluster"))
+ }
+ return(list(coefficients = coefficients, size = size,
+ cluster = cluster, is.loaded = is.loaded))
})
#' Make predictions from a k-means model
@@ -389,6 +409,56 @@ setMethod("ml.save", signature(object = "AFTSurvivalRegressionModel", path = "ch
invisible(callJMethod(writer, "save", path))
})
+#' Save the generalized linear model to the input path.
+#'
+#' @param object A fitted generalized linear model
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' which means throw exception if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "GeneralizedLinearRegressionModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+ })
+
+#' Save the k-means model to the input path.
+#'
+#' @param object A fitted k-means model
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' which means throw exception if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(x, centers = 2, algorithm="random")
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "KMeansModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+ })
+
#' Load a fitted MLlib model from the input path.
#'
#' @param path Path of the model to read.
@@ -408,6 +478,10 @@ ml.load <- function(path) {
return(new("NaiveBayesModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) {
return(new("AFTSurvivalRegressionModel", jobj = jobj))
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper")) {
+ return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) {
+ return(new("KMeansModel", jobj = jobj))
} else {
stop(paste("Unsupported model: ", jobj))
}
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 954abb00d4..6a822be121 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -126,6 +126,33 @@ test_that("glm summary", {
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
})
+test_that("glm save/load", {
+ training <- suppressWarnings(createDataFrame(sqlContext, iris))
+ m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
+ s <- summary(m)
+
+ modelPath <- tempfile(pattern = "glm", fileext = ".tmp")
+ ml.save(m, modelPath)
+ expect_error(ml.save(m, modelPath))
+ ml.save(m, modelPath, overwrite = TRUE)
+ m2 <- ml.load(modelPath)
+ s2 <- summary(m2)
+
+ expect_equal(s$coefficients, s2$coefficients)
+ expect_equal(rownames(s$coefficients), rownames(s2$coefficients))
+ expect_equal(s$dispersion, s2$dispersion)
+ expect_equal(s$null.deviance, s2$null.deviance)
+ expect_equal(s$deviance, s2$deviance)
+ expect_equal(s$df.null, s2$df.null)
+ expect_equal(s$df.residual, s2$df.residual)
+ expect_equal(s$aic, s2$aic)
+ expect_equal(s$iter, s2$iter)
+ expect_true(!s$is.loaded)
+ expect_true(s2$is.loaded)
+
+ unlink(modelPath)
+})
+
test_that("kmeans", {
newIris <- iris
newIris$Species <- NULL
@@ -150,6 +177,20 @@ test_that("kmeans", {
summary.model <- summary(model)
cluster <- summary.model$cluster
expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "kmeans", fileext = ".tmp")
+ ml.save(model, modelPath)
+ expect_error(ml.save(model, modelPath))
+ ml.save(model, modelPath, overwrite = TRUE)
+ model2 <- ml.load(modelPath)
+ summary2 <- summary(model2)
+ expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size)))
+ expect_equal(summary.model$coefficients, summary2$coefficients)
+ expect_true(!summary.model$is.loaded)
+ expect_true(summary2$is.loaded)
+
+ unlink(modelPath)
})
test_that("naiveBayes", {