aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-04-25 14:08:41 -0700
committerXiangrui Meng <meng@databricks.com>2016-04-25 14:08:41 -0700
commit9cb3ba1013a7eae11be8a00fa4a9c5308bb20195 (patch)
treeeb275db612f3bc4f438aa426bb49c528d6fc0fe9 /R/pkg
parent0c47e274ab8c286498fa002e2c92febcb53905c6 (diff)
downloadspark-9cb3ba1013a7eae11be8a00fa4a9c5308bb20195.tar.gz
spark-9cb3ba1013a7eae11be8a00fa4a9c5308bb20195.tar.bz2
spark-9cb3ba1013a7eae11be8a00fa4a9c5308bb20195.zip
[SPARK-14312][ML][SPARKR] NaiveBayes model persistence in SparkR
## What changes were proposed in this pull request? SparkR ```NaiveBayesModel``` supports ```save/load``` by the following API: ``` df <- createDataFrame(sqlContext, infert) model <- naiveBayes(education ~ ., df, laplace = 0) ml.save(model, path) model2 <- ml.load(path) ``` ## How was this patch tested? Add unit tests. cc mengxr Author: Yanbo Liang <ybliang8@gmail.com> Closes #12573 from yanboliang/spark-14312.
Diffstat (limited to 'R/pkg')
-rw-r--r--R/pkg/NAMESPACE6
-rw-r--r--R/pkg/R/generics.R4
-rw-r--r--R/pkg/R/mllib.R48
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R12
4 files changed, 68 insertions, 2 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 0f92b5e597..c0a63d6b3e 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -107,7 +107,8 @@ exportMethods("arrange",
"write.jdbc",
"write.json",
"write.parquet",
- "write.text")
+ "write.text",
+ "ml.save")
exportClasses("Column")
@@ -299,7 +300,8 @@ export("as.DataFrame",
"tableNames",
"tables",
"uncacheTable",
- "print.summary.GeneralizedLinearRegressionModel")
+ "print.summary.GeneralizedLinearRegressionModel",
+ "ml.load")
export("structField",
"structField.jobj",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 04274a12bc..f654d8330c 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1200,3 +1200,7 @@ setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBa
#' @rdname survreg
#' @export
setGeneric("survreg", function(formula, data, ...) { standardGeneric("survreg") })
+
+#' @rdname ml.save
+#' @export
+setGeneric("ml.save", function(object, path, ...) { standardGeneric("ml.save") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 7dd82963a1..cda6100e79 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -338,6 +338,54 @@ setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"),
return(new("NaiveBayesModel", jobj = jobj))
})
+#' Save the Bernoulli naive Bayes model to the input path.
+#'
+#' @param object A fitted Bernoulli naive Bayes model
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' which means throw exception if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(sqlContext, infert)
+#' model <- naiveBayes(education ~ ., df, laplace = 0)
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+ })
+
+#' Load a fitted MLlib model from the input path.
+#'
+#' @param path Path of the model to read.
+#' @return a fitted MLlib model
+#' @rdname ml.load
+#' @name ml.load
+#' @export
+#' @examples
+#' \dontrun{
+#' path <- "path/to/model"
+#' model <- ml.load(path)
+#' }
+ml.load <- function(path) {
+ path <- suppressWarnings(normalizePath(path))
+ jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path)
+ if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) {
+ return(new("NaiveBayesModel", jobj = jobj))
+ } else {
+ stop(paste("Unsupported model: ", jobj))
+ }
+}
+
#' Fit an accelerated failure time (AFT) survival regression model.
#'
#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 1597306bb6..63ec84e497 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -204,6 +204,18 @@ test_that("naiveBayes", {
"Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No",
"Yes", "Yes", "No", "No"))
+ # Test model save/load
+ modelPath <- tempfile(pattern = "naiveBayes", fileext = ".tmp")
+ ml.save(m, modelPath)
+ expect_error(ml.save(m, modelPath))
+ ml.save(m, modelPath, overwrite = TRUE)
+ m2 <- ml.load(modelPath)
+ s2 <- summary(m2)
+ expect_equal(s$apriori, s2$apriori)
+ expect_equal(s$tables, s2$tables)
+
+ unlink(modelPath)
+
# Test e1071::naiveBayes
if (requireNamespace("e1071", quietly = TRUE)) {
expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))