aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests/testthat/test_mllib.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/inst/tests/testthat/test_mllib.R')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R32
1 files changed, 32 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index bc18224680..b759b28927 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -476,4 +476,36 @@ test_that("spark.survreg", {
}
})
+test_that("spark.isotonicRegression", {
+ label <- c(7.0, 5.0, 3.0, 5.0, 1.0)
+ feature <- c(0.0, 1.0, 2.0, 3.0, 4.0)
+ weight <- c(1.0, 1.0, 1.0, 1.0, 1.0)
+ data <- as.data.frame(cbind(label, feature, weight))
+ df <- suppressWarnings(createDataFrame(data))
+
+ model <- spark.isoreg(df, label ~ feature, isotonic = FALSE,
+ weightCol = "weight")
+ # only allow one variable on the right hand side of the formula
+ expect_error(model2 <- spark.isoreg(df, ~., isotonic = FALSE))
+ result <- summary(model, df)
+ expect_equal(result$predictions, list(7, 5, 4, 4, 1))
+
+ # Test model prediction
+ predict_data <- list(list(-2.0), list(-1.0), list(0.5),
+ list(0.75), list(1.0), list(2.0), list(9.0))
+ predict_df <- createDataFrame(predict_data, c("feature"))
+ predict_result <- collect(select(predict(model, predict_df), "prediction"))
+ expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-isotonicRegression", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ expect_equal(result, summary(model2, df))
+
+ unlink(modelPath)
+})
+
sparkR.session.stop()