diff options
Diffstat (limited to 'R/pkg/R/mllib.R')
-rw-r--r-- | R/pkg/R/mllib.R | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index cda6100e79..480301192d 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -364,6 +364,31 @@ setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"), invisible(callJMethod(writer, "save", path)) }) +#' Save the AFT survival regression model to the input path. +#' +#' @param object A fitted AFT survival regression model +#' @param path The directory where the model is saved +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname ml.save +#' @name ml.save +#' @export +#' @examples +#' \dontrun{ +#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData) +#' path <- "path/to/model" +#' ml.save(model, path) +#' } +setMethod("ml.save", signature(object = "AFTSurvivalRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + writer <- callJMethod(object@jobj, "write") + if (overwrite) { + writer <- callJMethod(writer, "overwrite") + } + invisible(callJMethod(writer, "save", path)) + }) + #' Load a fitted MLlib model from the input path. #' #' @param path Path of the model to read. @@ -381,6 +406,8 @@ ml.load <- function(path) { jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path) if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) { return(new("NaiveBayesModel", jobj = jobj)) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) { + return(new("AFTSurvivalRegressionModel", jobj = jobj)) } else { stop(paste("Unsupported model: ", jobj)) } |