aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/R/mllib.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/R/mllib.R')
-rw-r--r--R/pkg/R/mllib.R48
1 files changed, 48 insertions, 0 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 7dd82963a1..cda6100e79 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -338,6 +338,54 @@ setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"),
return(new("NaiveBayesModel", jobj = jobj))
})
+#' Save the Bernoulli naive Bayes model to the input path.
+#'
+#' @param object A fitted Bernoulli naive Bayes model
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' which means throw exception if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(sqlContext, infert)
+#' model <- naiveBayes(education ~ ., df, laplace = 0)
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+ })
+
+#' Load a fitted MLlib model from the input path.
+#'
+#' @param path Path of the model to read.
+#' @return a fitted MLlib model
+#' @rdname ml.load
+#' @name ml.load
+#' @export
+#' @examples
+#' \dontrun{
+#' path <- "path/to/model"
+#' model <- ml.load(path)
+#' }
+ml.load <- function(path) {
+ path <- suppressWarnings(normalizePath(path))
+ jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path)
+ if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) {
+ return(new("NaiveBayesModel", jobj = jobj))
+ } else {
+ stop(paste("Unsupported model: ", jobj))
+ }
+}
+
#' Fit an accelerated failure time (AFT) survival regression model.
#'
#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().