diff options
author | Xusen Yin <yinxusen@gmail.com> | 2016-03-22 14:16:51 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-03-22 14:16:51 -0700 |
commit | d6dc12ef0146ae409834c78737c116050961f350 (patch) | |
tree | 7e99255f2a15ee2d088677253465ec6951b0a8d4 /R/pkg/inst/tests | |
parent | b2b1ad7d4cc3b3469c3d2c841b40b58ed0e34447 (diff) | |
download | spark-d6dc12ef0146ae409834c78737c116050961f350.tar.gz spark-d6dc12ef0146ae409834c78737c116050961f350.tar.bz2 spark-d6dc12ef0146ae409834c78737c116050961f350.zip |
[SPARK-13449] Naive Bayes wrapper in SparkR
## What changes were proposed in this pull request?
This PR continues the work in #11486 from yinxusen with some code refactoring. In R package e1071, `naiveBayes` supports both categorical (Bernoulli) and continuous features (Gaussian), while in MLlib we support Bernoulli and multinomial. This PR implements the common subset: Bernoulli.
I moved the implementation out from SparkRWrappers to NaiveBayesWrapper to make it easier to read. Argument names, default values, and summary now match e1071's naiveBayes.
I removed the preprocessing step that omits NA values because we don't know which columns to process.
## How was this patch tested?
Test against output from R package e1071's naiveBayes.
cc: yanboliang yinxusen
Closes #11486
Author: Xusen Yin <yinxusen@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>
Closes #11890 from mengxr/SPARK-13449.
Diffstat (limited to 'R/pkg/inst/tests')
-rw-r--r-- | R/pkg/inst/tests/testthat/test_mllib.R | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index e120462964..44b48369ef 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -141,3 +141,62 @@ test_that("kmeans", { cluster <- summary.model$cluster expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) }) + +test_that("naiveBayes", { + # R code to reproduce the result. + # We do not support instance weights yet. So we ignore the frequencies. + # + #' library(e1071) + #' t <- as.data.frame(Titanic) + #' t1 <- t[t$Freq > 0, -5] + #' m <- naiveBayes(Survived ~ ., data = t1) + #' m + #' predict(m, t1) + # + # -- output of 'm' + # + # A-priori probabilities: + # Y + # No Yes + # 0.4166667 0.5833333 + # + # Conditional probabilities: + # Class + # Y 1st 2nd 3rd Crew + # No 0.2000000 0.2000000 0.4000000 0.2000000 + # Yes 0.2857143 0.2857143 0.2857143 0.1428571 + # + # Sex + # Y Male Female + # No 0.5 0.5 + # Yes 0.5 0.5 + # + # Age + # Y Child Adult + # No 0.2000000 0.8000000 + # Yes 0.4285714 0.5714286 + # + # -- output of 'predict(m, t1)' + # + # Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No + # + + t <- as.data.frame(Titanic) + t1 <- t[t$Freq > 0, -5] + df <- suppressWarnings(createDataFrame(sqlContext, t1)) + m <- naiveBayes(Survived ~ ., data = df) + s <- summary(m) + expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6) + expect_equal(sum(s$apriori), 1) + expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6) + p <- collect(select(predict(m, df), "prediction")) + expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No", + "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No", + "Yes", "Yes", "No", "No")) + + # Test e1071::naiveBayes + if (requireNamespace("e1071", quietly = TRUE)) { + expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), 
not(throws_error())) + expect_equal(as.character(predict(m, t1[1, ])), "Yes") + } +}) |