diff options
Diffstat (limited to 'R/pkg/inst/tests/testthat/test_mllib.R')
-rw-r--r-- | R/pkg/inst/tests/testthat/test_mllib.R | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index e120462964..44b48369ef 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -141,3 +141,62 @@ test_that("kmeans", { cluster <- summary.model$cluster expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) }) + +test_that("naiveBayes", { + # R code to reproduce the result. + # We do not support instance weights yet. So we ignore the frequencies. + # + #' library(e1071) + #' t <- as.data.frame(Titanic) + #' t1 <- t[t$Freq > 0, -5] + #' m <- naiveBayes(Survived ~ ., data = t1) + #' m + #' predict(m, t1) + # + # -- output of 'm' + # + # A-priori probabilities: + # Y + # No Yes + # 0.4166667 0.5833333 + # + # Conditional probabilities: + # Class + # Y 1st 2nd 3rd Crew + # No 0.2000000 0.2000000 0.4000000 0.2000000 + # Yes 0.2857143 0.2857143 0.2857143 0.1428571 + # + # Sex + # Y Male Female + # No 0.5 0.5 + # Yes 0.5 0.5 + # + # Age + # Y Child Adult + # No 0.2000000 0.8000000 + # Yes 0.4285714 0.5714286 + # + # -- output of 'predict(m, t1)' + # + # Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No + # + + t <- as.data.frame(Titanic) + t1 <- t[t$Freq > 0, -5] + df <- suppressWarnings(createDataFrame(sqlContext, t1)) + m <- naiveBayes(Survived ~ ., data = df) + s <- summary(m) + expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6) + expect_equal(sum(s$apriori), 1) + expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6) + p <- collect(select(predict(m, df), "prediction")) + expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No", + "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No", + "Yes", "Yes", "No", "No")) + + # Test e1071::naiveBayes + if (requireNamespace("e1071", quietly = TRUE)) { + expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error())) + expect_equal(as.character(predict(m, t1[1, ])), "Yes") + } +}) |