about summary refs log tree commit diff
path: root/R/pkg/inst/tests
diff options
context:
space:
mode:
author Xusen Yin <yinxusen@gmail.com> 2016-03-22 14:16:51 -0700
committer Xiangrui Meng <meng@databricks.com> 2016-03-22 14:16:51 -0700
commit d6dc12ef0146ae409834c78737c116050961f350 (patch)
tree 7e99255f2a15ee2d088677253465ec6951b0a8d4 /R/pkg/inst/tests
parent b2b1ad7d4cc3b3469c3d2c841b40b58ed0e34447 (diff)
download spark-d6dc12ef0146ae409834c78737c116050961f350.tar.gz
spark-d6dc12ef0146ae409834c78737c116050961f350.tar.bz2
spark-d6dc12ef0146ae409834c78737c116050961f350.zip
[SPARK-13449] Naive Bayes wrapper in SparkR
## What changes were proposed in this pull request? This PR continues the work in #11486 from yinxusen with some code refactoring. In R package e1071, `naiveBayes` supports both categorical (Bernoulli) and continuous features (Gaussian), while in MLlib we support Bernoulli and multinomial. This PR implements the common subset: Bernoulli. I moved the implementation out from SparkRWrappers to NaiveBayesWrapper to make it easier to read. Argument names, default values, and summary now match e1071's naiveBayes. I removed the preprocessing step that omits NA values because we don't know which columns to process. ## How was this patch tested? Test against output from R package e1071's naiveBayes. cc: yanboliang yinxusen Closes #11486 Author: Xusen Yin <yinxusen@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #11890 from mengxr/SPARK-13449.
Diffstat (limited to 'R/pkg/inst/tests')
-rw-r--r-- R/pkg/inst/tests/testthat/test_mllib.R 59
1 files changed, 59 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index e120462964..44b48369ef 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -141,3 +141,62 @@ test_that("kmeans", {
cluster <- summary.model$cluster
expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
})
+
+test_that("naiveBayes", {
+ # Reference R code (package e1071) that produced the expected values asserted below.
+ # Instance weights are not supported yet, so the frequencies are dropped and ignored.
+ #
+ #' library(e1071)
+ #' t <- as.data.frame(Titanic)
+ #' t1 <- t[t$Freq > 0, -5]
+ #' m <- naiveBayes(Survived ~ ., data = t1)
+ #' m
+ #' predict(m, t1)
+ #
+ # -- output of 'm'
+ #
+ # A-priori probabilities:
+ # Y
+ # No Yes
+ # 0.4166667 0.5833333
+ #
+ # Conditional probabilities:
+ # Class
+ # Y 1st 2nd 3rd Crew
+ # No 0.2000000 0.2000000 0.4000000 0.2000000
+ # Yes 0.2857143 0.2857143 0.2857143 0.1428571
+ #
+ # Sex
+ # Y Male Female
+ # No 0.5 0.5
+ # Yes 0.5 0.5
+ #
+ # Age
+ # Y Child Adult
+ # No 0.2000000 0.8000000
+ # Yes 0.4285714 0.5714286
+ #
+ # -- output of 'predict(m, t1)'
+ #
+ # Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No
+ #
+ # Fit the same model in SparkR; spot-check the prior, one conditional, and all predictions.
+ t <- as.data.frame(Titanic)
+ t1 <- t[t$Freq > 0, -5]  # keep rows with Freq > 0, then drop column 5 (Freq)
+ df <- suppressWarnings(createDataFrame(sqlContext, t1))
+ m <- naiveBayes(Survived ~ ., data = df)
+ s <- summary(m)
+ expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6)
+ expect_equal(sum(s$apriori), 1)  # the priors must sum to one
+ expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6)
+ p <- collect(select(predict(m, df), "prediction"))
+ expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No",
+ "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No",
+ "Yes", "Yes", "No", "No"))
+
+ # When e1071 is installed, check e1071::naiveBayes itself runs and predicts "Yes" for row 1.
+ if (requireNamespace("e1071", quietly = TRUE)) {
+ expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))
+ expect_equal(as.character(predict(m, t1[1, ])), "Yes")
+ }
+})