aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests/testthat/test_mllib.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/inst/tests/testthat/test_mllib.R')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R49
1 files changed, 49 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 44b48369ef..fdb591756e 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -200,3 +200,52 @@ test_that("naiveBayes", {
expect_equal(as.character(predict(m, t1[1, ])), "Yes")
}
})
+
+test_that("survreg", {
+ # R code to reproduce the result.
+ #
+ #' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
+ #' x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
+ #' library(survival)
+ #' model <- survreg(Surv(time, status) ~ x + sex, rData)
+ #' summary(model)
+ #' predict(model, data)
+ #
+ # -- output of 'summary(model)'
+ #
+ # Value Std. Error z p
+ # (Intercept) 1.315 0.270 4.88 1.07e-06
+ # x -0.190 0.173 -1.10 2.72e-01
+ # sex -0.253 0.329 -0.77 4.42e-01
+ # Log(scale) -1.160 0.396 -2.93 3.41e-03
+ #
+ # -- output of 'predict(model, data)'
+ #
+ # 1 2 3 4 5 6 7
+ # 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269
+ #
+ data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0),
+ list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1))
+ df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex"))
+ model <- survreg(Surv(time, status) ~ x + sex, df)
+ stats <- summary(model)
+ coefs <- as.vector(stats$coefficients[, 1])
+ rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800)
+ expect_equal(coefs, rCoefs, tolerance = 1e-4)
+ expect_true(all(
+ rownames(stats$coefficients) ==
+ c("(Intercept)", "x", "sex", "Log(scale)")))
+ p <- collect(select(predict(model, df), "prediction"))
+ expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035,
+ 2.390146, 2.891269, 2.891269), tolerance = 1e-4)
+
+ # Test survival::survreg
+ if (requireNamespace("survival", quietly = TRUE)) {
+ rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
+ x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
+ expect_that(
+ model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData),
+ not(throws_error()))
+ expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4)
+ }
+})