aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/inst')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R22
1 files changed, 22 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index ab390a86d1..bc18224680 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -118,6 +118,28 @@ test_that("spark.glm summary", {
expect_equal(stats$df.residual, rStats$df.residual)
expect_equal(stats$aic, rStats$aic)
+ # Test spark.glm works with weighted dataset
+ a1 <- c(0, 1, 2, 3)
+ a2 <- c(5, 2, 1, 3)
+ w <- c(1, 2, 3, 4)
+ b <- c(1, 0, 1, 0)
+ data <- as.data.frame(cbind(a1, a2, w, b))
+ df <- suppressWarnings(createDataFrame(data))
+
+ stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w"))
+ rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w))
+
+ coefs <- unlist(stats$coefficients)
+ rCoefs <- unlist(rStats$coefficients)
+ expect_true(all(abs(rCoefs - coefs) < 1e-3))
+ expect_true(all(rownames(stats$coefficients) == c("(Intercept)", "a1", "a2")))
+ expect_equal(stats$dispersion, rStats$dispersion)
+ expect_equal(stats$null.deviance, rStats$null.deviance)
+ expect_equal(stats$deviance, rStats$deviance)
+ expect_equal(stats$df.null, rStats$df.null)
+ expect_equal(stats$df.residual, rStats$df.residual)
+ expect_equal(stats$aic, rStats$aic)
+
# Test summary works on base GLM models
baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
baseSummary <- summary(baseModel)