aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests/test_mllib.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/inst/tests/test_mllib.R')
-rw-r--r--R/pkg/inst/tests/test_mllib.R17
1 files changed, 17 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 3331ce7383..032cfef061 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -67,3 +67,20 @@ test_that("summary coefficients match with native glm", {
as.character(stats$features) ==
c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
})
+
+test_that("summary coefficients match with native glm of family 'binomial'", {
+ df <- createDataFrame(sqlContext, iris)
+ training <- filter(df, df$Species != "setosa")
+ stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
+ family = "binomial"))
+ coefs <- as.vector(stats$coefficients)
+
+ rTraining <- iris[iris$Species %in% c("versicolor","virginica"),]
+ rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
+ family = binomial(link = "logit"))))
+
+ expect_true(all(abs(rCoefs - coefs) < 1e-4))
+ expect_true(all(
+ as.character(stats$features) ==
+ c("(Intercept)", "Sepal_Length", "Sepal_Width")))
+})