aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests/testthat/test_mllib_classification.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/inst/tests/testthat/test_mllib_classification.R')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib_classification.R44
1 files changed, 44 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R
index 5f84a620c1..620f528f2e 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_classification.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R
@@ -27,6 +27,50 @@ absoluteSparkPath <- function(x) {
file.path(sparkHome, x)
}
+test_that("spark.svmLinear", {
+ df <- suppressWarnings(createDataFrame(iris))
+ training <- df[df$Species %in% c("versicolor", "virginica"), ]
+ model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10)
+ summary <- summary(model)
+
+ # test summary coefficients return matrix type
+ expect_true(class(summary$coefficients) == "matrix")
+ expect_true(class(summary$coefficients[, 1]) == "numeric")
+
+ coefs <- summary$coefficients[, "Estimate"]
+ expected_coefs <- c(-0.1563083, -0.460648, 0.2276626, 1.055085)
+ expect_true(all(abs(coefs - expected_coefs) < 0.1))
+ expect_equal(summary$intercept, -0.06004978, tolerance = 1e-2)
+
+ # Test prediction with string label
+ prediction <- predict(model, training)
+ expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character")
+ expected <- c("versicolor", "versicolor", "versicolor", "virginica", "virginica",
+ "virginica", "virginica", "virginica", "virginica", "virginica")
+ expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected)
+
+ # Test model save and load
+ modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ coefs <- summary(model)$coefficients
+ coefs2 <- summary(model2)$coefficients
+ expect_equal(coefs, coefs2)
+ unlink(modelPath)
+
+ # Test prediction with numeric label
+ label <- c(0.0, 0.0, 0.0, 1.0, 1.0)
+ feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776)
+ data <- as.data.frame(cbind(label, feature))
+ df <- createDataFrame(data)
+ model <- spark.svmLinear(df, label ~ feature, regParam = 0.1)
+ prediction <- collect(select(predict(model, df), "prediction"))
+ expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))
+
+})
+
test_that("spark.logit", {
# R code to reproduce the result.
# nolint start