aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-04-26 10:30:24 -0700
committerXiangrui Meng <meng@databricks.com>2016-04-26 10:30:24 -0700
commit92f66331b4ba3634f54f57ddb5e7962b14aa4ca1 (patch)
tree55ea47996ed3688041cea019c24908931c31eddf /R/pkg
parent162cf02efa025fdb32adc3eaabb8e4232fe90e08 (diff)
downloadspark-92f66331b4ba3634f54f57ddb5e7962b14aa4ca1.tar.gz
spark-92f66331b4ba3634f54f57ddb5e7962b14aa4ca1.tar.bz2
spark-92f66331b4ba3634f54f57ddb5e7962b14aa4ca1.zip
[SPARK-14313][ML][SPARKR] AFTSurvivalRegression model persistence in SparkR
## What changes were proposed in this pull request? ```AFTSurvivalRegressionModel``` supports ```save/load``` in SparkR. ## How was this patch tested? Unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #12685 from yanboliang/spark-14313.
Diffstat (limited to 'R/pkg')
-rw-r--r--R/pkg/R/mllib.R27
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R13
2 files changed, 40 insertions, 0 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index cda6100e79..480301192d 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -364,6 +364,31 @@ setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"),
invisible(callJMethod(writer, "save", path))
})
+#' Save the AFT survival regression model to the input path.
+#'
+#' @param object A fitted AFT survival regression model
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' which means throw exception if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "AFTSurvivalRegressionModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+ })
+
#' Load a fitted MLlib model from the input path.
#'
#' @param path Path of the model to read.
@@ -381,6 +406,8 @@ ml.load <- function(path) {
jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path)
if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) {
return(new("NaiveBayesModel", jobj = jobj))
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) {
+ return(new("AFTSurvivalRegressionModel", jobj = jobj))
} else {
stop(paste("Unsupported model: ", jobj))
}
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 63ec84e497..954abb00d4 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -261,6 +261,19 @@ test_that("survreg", {
expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035,
2.390146, 2.891269, 2.891269), tolerance = 1e-4)
+ # Test model save/load
+ modelPath <- tempfile(pattern = "survreg", fileext = ".tmp")
+ ml.save(model, modelPath)
+ expect_error(ml.save(model, modelPath))
+ ml.save(model, modelPath, overwrite = TRUE)
+ model2 <- ml.load(modelPath)
+ stats2 <- summary(model2)
+ coefs2 <- as.vector(stats2$coefficients[, 1])
+ expect_equal(coefs, coefs2)
+ expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients))
+
+ unlink(modelPath)
+
# Test survival::survreg
if (requireNamespace("survival", quietly = TRUE)) {
rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),