diff options
Diffstat (limited to 'R/pkg/inst/tests')
-rw-r--r-- | R/pkg/inst/tests/testthat/test_mllib_clustering.R | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index 1980fffd80..f013991002 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -132,6 +132,26 @@ test_that("spark.kmeans", { expect_true(summary2$is.loaded) unlink(modelPath) + + # Test Kmeans on dataset that is sensitive to seed value + col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + col2 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + col3 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + cols <- as.data.frame(cbind(col1, col2, col3)) + df <- createDataFrame(cols) + + model1 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, + initMode = "random", seed = 1, tol = 1E-5) + model2 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, + initMode = "random", seed = 22222, tol = 1E-5) + + fitted.model1 <- fitted(model1) + fitted.model2 <- fitted(model2) + # The predicted clusters are different + expect_equal(sort(collect(distinct(select(fitted.model1, "prediction")))$prediction), + c(0, 1, 2, 3)) + expect_equal(sort(collect(distinct(select(fitted.model2, "prediction")))$prediction), + c(0, 1, 2)) }) test_that("spark.lda with libsvm", { |