diff options
Diffstat (limited to 'R/pkg/inst/tests/testthat')
-rw-r--r-- | R/pkg/inst/tests/testthat/test_mllib.R | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index de9bd48662..1e6da650d1 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -347,6 +347,38 @@ test_that("spark.kmeans", { unlink(modelPath) }) +test_that("spark.mlp", { + df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") + model <- spark.mlp(df, blockSize = 128, layers = c(4, 5, 4, 3), solver = "l-bfgs", maxIter = 100, + tol = 0.5, stepSize = 1, seed = 1) + + # Test summary method + summary <- summary(model) + expect_equal(summary$labelCount, 3) + expect_equal(summary$layers, c(4, 5, 4, 3)) + expect_equal(length(summary$weights), 64) + + # Test predict method + mlpTestDF <- df + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 6), c(0, 1, 1, 1, 1, 1)) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + + expect_equal(summary2$labelCount, 3) + expect_equal(summary2$layers, c(4, 5, 4, 3)) + expect_equal(length(summary2$weights), 64) + + unlink(modelPath) + +}) + test_that("spark.naiveBayes", { # R code to reproduce the result. # We do not support instance weights yet. So we ignore the frequencies. |