aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests/testthat/test_mllib_classification.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/inst/tests/testthat/test_mllib_classification.R')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib_classification.R10
1 files changed, 9 insertions, 1 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R
index 620f528f2e..459254d271 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_classification.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R
@@ -211,7 +211,15 @@ test_that("spark.logit", {
df <- createDataFrame(data)
model <- spark.logit(df, label ~ feature)
prediction <- collect(select(predict(model, df), "prediction"))
- expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0"))
+ expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))
+
+ # Test prediction with weightCol
+ weight <- c(2.0, 2.0, 2.0, 1.0, 1.0)
+ data2 <- as.data.frame(cbind(label, feature, weight))
+ df2 <- createDataFrame(data2)
+ model2 <- spark.logit(df2, label ~ feature, weightCol = "weight")
+ prediction2 <- collect(select(predict(model2, df2), "prediction"))
+ expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0"))
})
test_that("spark.mlp", {