diff options
Diffstat (limited to 'R/pkg/inst/tests/testthat/test_mllib_classification.R')
-rw-r--r-- | R/pkg/inst/tests/testthat/test_mllib_classification.R | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index 5f84a620c1..620f528f2e 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -27,6 +27,50 @@ absoluteSparkPath <- function(x) { file.path(sparkHome, x) } +test_that("spark.svmLinear", { + df <- suppressWarnings(createDataFrame(iris)) + training <- df[df$Species %in% c("versicolor", "virginica"), ] + model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10) + summary <- summary(model) + + # test summary coefficients return matrix type + expect_true(class(summary$coefficients) == "matrix") + expect_true(class(summary$coefficients[, 1]) == "numeric") + + coefs <- summary$coefficients[, "Estimate"] + expected_coefs <- c(-0.1563083, -0.460648, 0.2276626, 1.055085) + expect_true(all(abs(coefs - expected_coefs) < 0.1)) + expect_equal(summary$intercept, -0.06004978, tolerance = 1e-2) + + # Test prediction with string label + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character") + expected <- c("versicolor", "versicolor", "versicolor", "virginica", "virginica", + "virginica", "virginica", "virginica", "virginica", "virginica") + expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected) + + # Test model save and load + modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) + unlink(modelPath) + + # Test prediction with numeric label + label <- c(0.0, 0.0, 0.0, 1.0, 1.0) + feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) + data <- as.data.frame(cbind(label, feature)) + df <- createDataFrame(data) + model <- spark.svmLinear(df, label ~ feature, regParam = 0.1) + prediction <- collect(select(predict(model, df), "prediction")) + expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0")) + +}) + test_that("spark.logit", { # R code to reproduce the result. # nolint start |