aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-03-22 14:16:51 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-22 14:16:51 -0700
commitd6dc12ef0146ae409834c78737c116050961f350 (patch)
tree7e99255f2a15ee2d088677253465ec6951b0a8d4 /R
parentb2b1ad7d4cc3b3469c3d2c841b40b58ed0e34447 (diff)
downloadspark-d6dc12ef0146ae409834c78737c116050961f350.tar.gz
spark-d6dc12ef0146ae409834c78737c116050961f350.tar.bz2
spark-d6dc12ef0146ae409834c78737c116050961f350.zip
[SPARK-13449] Naive Bayes wrapper in SparkR
## What changes were proposed in this pull request? This PR continues the work in #11486 from yinxusen with some code refactoring. In R package e1071, `naiveBayes` supports both categorical (Bernoulli) and continuous features (Gaussian), while in MLlib we support Bernoulli and multinomial. This PR implements the common subset: Bernoulli. I moved the implementation out from SparkRWrappers to NaiveBayesWrapper to make it easier to read. Argument names, default values, and summary now match e1071's naiveBayes. I removed the preprocess part that omit NA values because we don't know which columns to process. ## How was this patch tested? Test against output from R package e1071's naiveBayes. cc: yanboliang yinxusen Closes #11486 Author: Xusen Yin <yinxusen@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #11890 from mengxr/SPARK-13449.
Diffstat (limited to 'R')
-rw-r--r--R/pkg/DESCRIPTION3
-rw-r--r--R/pkg/NAMESPACE3
-rw-r--r--R/pkg/R/generics.R4
-rw-r--r--R/pkg/R/mllib.R91
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R59
5 files changed, 153 insertions, 7 deletions
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 0cd0d75df0..e26f9a7a2a 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -11,7 +11,8 @@ Depends:
R (>= 3.0),
methods,
Suggests:
- testthat
+ testthat,
+ e1071
Description: R frontend for Spark
License: Apache License (== 2.0)
Collate:
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 636d39e1e9..5d8a4b1d6e 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -15,7 +15,8 @@ exportMethods("glm",
"predict",
"summary",
"kmeans",
- "fitted")
+ "fitted",
+ "naiveBayes")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 6ad71fcb46..46b115f45e 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1175,3 +1175,7 @@ setGeneric("kmeans")
#' @rdname fitted
#' @export
setGeneric("fitted")
+
+#' @rdname naiveBayes
+#' @export
+setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 5c0d3dcf3a..2555019369 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -22,6 +22,11 @@
#' @export
setClass("PipelineModel", representation(model = "jobj"))
+#' @title S4 class that represents a NaiveBayesModel
+#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper
+#' @export
+setClass("NaiveBayesModel", representation(jobj = "jobj"))
+
#' Fits a generalized linear model
#'
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -42,7 +47,7 @@ setClass("PipelineModel", representation(model = "jobj"))
#' @rdname glm
#' @export
#' @examples
-#'\dontrun{
+#' \dontrun{
#' sc <- sparkR.init()
#' sqlContext <- sparkRSQL.init(sc)
#' data(iris)
@@ -71,7 +76,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
#' @rdname predict
#' @export
#' @examples
-#'\dontrun{
+#' \dontrun{
#' model <- glm(y ~ x, trainingData)
#' predicted <- predict(model, testData)
#' showDF(predicted)
@@ -81,6 +86,26 @@ setMethod("predict", signature(object = "PipelineModel"),
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
})
+#' Make predictions from a naive Bayes model
+#'
+#' Makes predictions from a model produced by naiveBayes(), similarly to R package e1071's predict.
+#'
+#' @param object A fitted naive Bayes model
+#' @param newData DataFrame for testing
+#' @return DataFrame containing predicted labels in a column named "prediction"
+#' @rdname predict
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- naiveBayes(y ~ x, trainingData)
+#' predicted <- predict(model, testData)
+#' showDF(predicted)
+#'}
+setMethod("predict", signature(object = "NaiveBayesModel"),
+ function(object, newData) {
+ return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ })
+
#' Get the summary of a model
#'
#' Returns the summary of a model produced by glm(), similarly to R's summary().
@@ -97,7 +122,7 @@ setMethod("predict", signature(object = "PipelineModel"),
#' @rdname summary
#' @export
#' @examples
-#'\dontrun{
+#' \dontrun{
#' model <- glm(y ~ x, trainingData)
#' summary(model)
#'}
@@ -140,6 +165,35 @@ setMethod("summary", signature(object = "PipelineModel"),
}
})
+#' Get the summary of a naive Bayes model
+#'
+#' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary().
+#'
+#' @param object A fitted MLlib model
+#' @return a list containing 'apriori', the label distribution, and 'tables', conditional
+# probabilities given the target label
+#' @rdname summary
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- naiveBayes(y ~ x, trainingData)
+#' summary(model)
+#'}
+setMethod("summary", signature(object = "NaiveBayesModel"),
+ function(object, ...) {
+ jobj <- object@jobj
+ features <- callJMethod(jobj, "features")
+ labels <- callJMethod(jobj, "labels")
+ apriori <- callJMethod(jobj, "apriori")
+ apriori <- t(as.matrix(unlist(apriori)))
+ colnames(apriori) <- unlist(labels)
+ tables <- callJMethod(jobj, "tables")
+ tables <- matrix(tables, nrow = length(labels))
+ rownames(tables) <- unlist(labels)
+ colnames(tables) <- unlist(features)
+ return(list(apriori = apriori, tables = tables))
+ })
+
#' Fit a k-means model
#'
#' Fit a k-means model, similarly to R's kmeans().
@@ -152,7 +206,7 @@ setMethod("summary", signature(object = "PipelineModel"),
#' @rdname kmeans
#' @export
#' @examples
-#'\dontrun{
+#' \dontrun{
#' model <- kmeans(x, centers = 2, algorithm="random")
#'}
setMethod("kmeans", signature(x = "DataFrame"),
@@ -173,7 +227,7 @@ setMethod("kmeans", signature(x = "DataFrame"),
#' @rdname fitted
#' @export
#' @examples
-#'\dontrun{
+#' \dontrun{
#' model <- kmeans(trainingData, 2)
#' fitted.model <- fitted(model)
#' showDF(fitted.model)
@@ -192,3 +246,30 @@ setMethod("fitted", signature(object = "PipelineModel"),
stop(paste("Unsupported model", modelName, sep = " "))
}
})
+
+#' Fit a Bernoulli naive Bayes model
+#'
+#' Fit a Bernoulli naive Bayes model, similarly to R package e1071's naiveBayes() while only
+#' categorical features are supported. The input should be a DataFrame of observations instead of a
+#' contingency table.
+#'
+#' @param object A symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~', '.', ':', '+', and '-'.
+#' @param data DataFrame for training
+#' @param laplace Smoothing parameter
+#' @return a fitted naive Bayes model
+#' @rdname naiveBayes
+#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(sqlContext, infert)
+#' model <- naiveBayes(education ~ ., df, laplace = 0)
+#'}
+setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"),
+ function(formula, data, laplace = 0, ...) {
+ formula <- paste(deparse(formula), collapse = "")
+ jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
+ formula, data@sdf, laplace)
+ return(new("NaiveBayesModel", jobj = jobj))
+ })
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index e120462964..44b48369ef 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -141,3 +141,62 @@ test_that("kmeans", {
cluster <- summary.model$cluster
expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
})
+
+test_that("naiveBayes", {
+ # R code to reproduce the result.
+ # We do not support instance weights yet. So we ignore the frequencies.
+ #
+ #' library(e1071)
+ #' t <- as.data.frame(Titanic)
+ #' t1 <- t[t$Freq > 0, -5]
+ #' m <- naiveBayes(Survived ~ ., data = t1)
+ #' m
+ #' predict(m, t1)
+ #
+ # -- output of 'm'
+ #
+ # A-priori probabilities:
+ # Y
+ # No Yes
+ # 0.4166667 0.5833333
+ #
+ # Conditional probabilities:
+ # Class
+ # Y 1st 2nd 3rd Crew
+ # No 0.2000000 0.2000000 0.4000000 0.2000000
+ # Yes 0.2857143 0.2857143 0.2857143 0.1428571
+ #
+ # Sex
+ # Y Male Female
+ # No 0.5 0.5
+ # Yes 0.5 0.5
+ #
+ # Age
+ # Y Child Adult
+ # No 0.2000000 0.8000000
+ # Yes 0.4285714 0.5714286
+ #
+ # -- output of 'predict(m, t1)'
+ #
+ # Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No
+ #
+
+ t <- as.data.frame(Titanic)
+ t1 <- t[t$Freq > 0, -5]
+ df <- suppressWarnings(createDataFrame(sqlContext, t1))
+ m <- naiveBayes(Survived ~ ., data = df)
+ s <- summary(m)
+ expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6)
+ expect_equal(sum(s$apriori), 1)
+ expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6)
+ p <- collect(select(predict(m, df), "prediction"))
+ expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No",
+ "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No",
+ "Yes", "Yes", "No", "No"))
+
+ # Test e1071::naiveBayes
+ if (requireNamespace("e1071", quietly = TRUE)) {
+ expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))
+ expect_equal(as.character(predict(m, t1[1, ])), "Yes")
+ }
+})