aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests/testthat/test_mllib_regression.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/inst/tests/testthat/test_mllib_regression.R')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib_regression.R27
1 files changed, 15 insertions, 12 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R
index c450a15171..81a5bdc414 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_regression.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R
@@ -87,11 +87,14 @@ test_that("spark.glm summary", {
# gaussian family
training <- suppressWarnings(createDataFrame(iris))
stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species))
-
rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))
- coefs <- unlist(stats$coefficients)
- rCoefs <- unlist(rStats$coefficients)
+ # test summary coefficients return matrix type
+ expect_true(class(stats$coefficients) == "matrix")
+ expect_true(class(stats$coefficients[, 1]) == "numeric")
+
+ coefs <- stats$coefficients
+ rCoefs <- rStats$coefficients
expect_true(all(abs(rCoefs - coefs) < 1e-4))
expect_true(all(
rownames(stats$coefficients) ==
@@ -117,8 +120,8 @@ test_that("spark.glm summary", {
rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
family = binomial(link = "logit")))
- coefs <- unlist(stats$coefficients)
- rCoefs <- unlist(rStats$coefficients)
+ coefs <- stats$coefficients
+ rCoefs <- rStats$coefficients
expect_true(all(abs(rCoefs - coefs) < 1e-4))
expect_true(all(
rownames(stats$coefficients) ==
@@ -141,8 +144,8 @@ test_that("spark.glm summary", {
stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w"))
rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w))
- coefs <- unlist(stats$coefficients)
- rCoefs <- unlist(rStats$coefficients)
+ coefs <- stats$coefficients
+ rCoefs <- rStats$coefficients
expect_true(all(abs(rCoefs - coefs) < 1e-3))
expect_true(all(rownames(stats$coefficients) == c("(Intercept)", "a1", "a2")))
expect_equal(stats$dispersion, rStats$dispersion)
@@ -169,7 +172,7 @@ test_that("spark.glm summary", {
data <- as.data.frame(cbind(A, b))
df <- createDataFrame(data)
stats <- summary(spark.glm(df, b ~ . - 1))
- coefs <- unlist(stats$coefficients)
+ coefs <- stats$coefficients
expect_true(all(abs(c(0.5, 0.25) - coefs) < 1e-4))
})
@@ -259,8 +262,8 @@ test_that("glm summary", {
rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))
- coefs <- unlist(stats$coefficients)
- rCoefs <- unlist(rStats$coefficients)
+ coefs <- stats$coefficients
+ rCoefs <- rStats$coefficients
expect_true(all(abs(rCoefs - coefs) < 1e-4))
expect_true(all(
rownames(stats$coefficients) ==
@@ -282,8 +285,8 @@ test_that("glm summary", {
rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
family = binomial(link = "logit")))
- coefs <- unlist(stats$coefficients)
- rCoefs <- unlist(rStats$coefficients)
+ coefs <- stats$coefficients
+ rCoefs <- rStats$coefficients
expect_true(all(abs(rCoefs - coefs) < 1e-4))
expect_true(all(
rownames(stats$coefficients) ==