diff options
Diffstat (limited to 'R/pkg/R/mllib.R')
-rw-r--r-- | R/pkg/R/mllib.R | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 7dd82963a1..cda6100e79 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -338,6 +338,54 @@ setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"), return(new("NaiveBayesModel", jobj = jobj)) }) +#' Save the Bernoulli naive Bayes model to the input path. +#' +#' @param object A fitted Bernoulli naive Bayes model +#' @param path The directory where the model is saved +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname ml.save +#' @name ml.save +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(sqlContext, infert) +#' model <- naiveBayes(education ~ ., df, laplace = 0) +#' path <- "path/to/model" +#' ml.save(model, path) +#' } +setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"), + function(object, path, overwrite = FALSE) { + writer <- callJMethod(object@jobj, "write") + if (overwrite) { + writer <- callJMethod(writer, "overwrite") + } + invisible(callJMethod(writer, "save", path)) + }) + +#' Load a fitted MLlib model from the input path. +#' +#' @param path Path of the model to read. +#' @return a fitted MLlib model +#' @rdname ml.load +#' @name ml.load +#' @export +#' @examples +#' \dontrun{ +#' path <- "path/to/model" +#' model <- ml.load(path) +#' } +ml.load <- function(path) { + path <- suppressWarnings(normalizePath(path)) + jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path) + if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) { + return(new("NaiveBayesModel", jobj = jobj)) + } else { + stop(paste("Unsupported model: ", jobj)) + } +} + #' Fit an accelerated failure time (AFT) survival regression model. #' #' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg(). |