diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-04-25 14:08:41 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-04-25 14:08:41 -0700 |
commit | 9cb3ba1013a7eae11be8a00fa4a9c5308bb20195 (patch) | |
tree | eb275db612f3bc4f438aa426bb49c528d6fc0fe9 /R/pkg | |
parent | 0c47e274ab8c286498fa002e2c92febcb53905c6 (diff) | |
download | spark-9cb3ba1013a7eae11be8a00fa4a9c5308bb20195.tar.gz spark-9cb3ba1013a7eae11be8a00fa4a9c5308bb20195.tar.bz2 spark-9cb3ba1013a7eae11be8a00fa4a9c5308bb20195.zip |
[SPARK-14312][ML][SPARKR] NaiveBayes model persistence in SparkR
## What changes were proposed in this pull request?
SparkR ```NaiveBayesModel``` supports ```save/load``` by the following API:
```
df <- createDataFrame(sqlContext, infert)
model <- naiveBayes(education ~ ., df, laplace = 0)
ml.save(model, path)
model2 <- ml.load(path)
```
## How was this patch tested?
Add unit tests.
cc mengxr
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #12573 from yanboliang/spark-14312.
Diffstat (limited to 'R/pkg')
-rw-r--r-- | R/pkg/NAMESPACE | 6 | ||||
-rw-r--r-- | R/pkg/R/generics.R | 4 | ||||
-rw-r--r-- | R/pkg/R/mllib.R | 48 | ||||
-rw-r--r-- | R/pkg/inst/tests/testthat/test_mllib.R | 12 |
4 files changed, 68 insertions, 2 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 0f92b5e597..c0a63d6b3e 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -107,7 +107,8 @@ exportMethods("arrange", "write.jdbc", "write.json", "write.parquet", - "write.text") + "write.text", + "ml.save") exportClasses("Column") @@ -299,7 +300,8 @@ export("as.DataFrame", "tableNames", "tables", "uncacheTable", - "print.summary.GeneralizedLinearRegressionModel") + "print.summary.GeneralizedLinearRegressionModel", + "ml.load") export("structField", "structField.jobj", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 04274a12bc..f654d8330c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1200,3 +1200,7 @@ setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBa #' @rdname survreg #' @export setGeneric("survreg", function(formula, data, ...) { standardGeneric("survreg") }) + +#' @rdname ml.save +#' @export +setGeneric("ml.save", function(object, path, ...) { standardGeneric("ml.save") }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 7dd82963a1..cda6100e79 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -338,6 +338,54 @@ setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"), return(new("NaiveBayesModel", jobj = jobj)) }) +#' Save the Bernoulli naive Bayes model to the input path. +#' +#' @param object A fitted Bernoulli naive Bayes model +#' @param path The directory where the model is saved +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname ml.save +#' @name ml.save +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(sqlContext, infert) +#' model <- naiveBayes(education ~ ., df, laplace = 0) +#' path <- "path/to/model" +#' ml.save(model, path) +#' } +setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"), + function(object, path, overwrite = FALSE) { + writer <- callJMethod(object@jobj, "write") + if (overwrite) { + writer <- callJMethod(writer, "overwrite") + } + invisible(callJMethod(writer, "save", path)) + }) + +#' Load a fitted MLlib model from the input path. +#' +#' @param path Path of the model to read. +#' @return a fitted MLlib model +#' @rdname ml.load +#' @name ml.load +#' @export +#' @examples +#' \dontrun{ +#' path <- "path/to/model" +#' model <- ml.load(path) +#' } +ml.load <- function(path) { + path <- suppressWarnings(normalizePath(path)) + jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path) + if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) { + return(new("NaiveBayesModel", jobj = jobj)) + } else { + stop(paste("Unsupported model: ", jobj)) + } +} + #' Fit an accelerated failure time (AFT) survival regression model. #' #' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg(). diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 1597306bb6..63ec84e497 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -204,6 +204,18 @@ test_that("naiveBayes", { "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No")) + # Test model save/load + modelPath <- tempfile(pattern = "naiveBayes", fileext = ".tmp") + ml.save(m, modelPath) + expect_error(ml.save(m, modelPath)) + ml.save(m, modelPath, overwrite = TRUE) + m2 <- ml.load(modelPath) + s2 <- summary(m2) + expect_equal(s$apriori, s2$apriori) + expect_equal(s$tables, s2$tables) + + unlink(modelPath) + # Test e1071::naiveBayes if (requireNamespace("e1071", quietly = TRUE)) { expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error())) |