diff options
author | wm624@hotmail.com <wm624@hotmail.com> | 2017-01-12 22:27:57 -0800 |
---|---|---|
committer | Yanbo Liang <ybliang8@gmail.com> | 2017-01-12 22:27:57 -0800 |
commit | 7f24a0b6c32c56a38cf879d953bbd523922ab9c9 (patch) | |
tree | d60ea1d9a8fcf309fb5c938452ac7018fbc5dd38 /R/pkg/inst/tests | |
parent | 3356b8b6a9184fcab8d0fe993f3545c3beaa4d99 (diff) | |
download | spark-7f24a0b6c32c56a38cf879d953bbd523922ab9c9.tar.gz spark-7f24a0b6c32c56a38cf879d953bbd523922ab9c9.tar.bz2 spark-7f24a0b6c32c56a38cf879d953bbd523922ab9c9.zip |
[SPARK-19142][SPARKR] spark.kmeans should take seed, initSteps, and tol as parameters
## What changes were proposed in this pull request?
spark.kmeans doesn't have an interface to set initSteps, seed, and tol. Because the Spark KMeans algorithm doesn't take the same set of parameters as R's kmeans, we should maintain a different interface in spark.kmeans.
Add the missing parameters and the corresponding documentation.
Modified existing unit tests to take additional parameters.
Author: wm624@hotmail.com <wm624@hotmail.com>
Closes #16523 from wangmiao1981/kmeans.
Diffstat (limited to 'R/pkg/inst/tests')
-rw-r--r-- | R/pkg/inst/tests/testthat/test_mllib_clustering.R | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index 1980fffd80..f013991002 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -132,6 +132,26 @@ test_that("spark.kmeans", { expect_true(summary2$is.loaded) unlink(modelPath) + + # Test Kmeans on dataset that is sensitive to seed value + col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + col2 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + col3 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + cols <- as.data.frame(cbind(col1, col2, col3)) + df <- createDataFrame(cols) + + model1 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, + initMode = "random", seed = 1, tol = 1E-5) + model2 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, + initMode = "random", seed = 22222, tol = 1E-5) + + fitted.model1 <- fitted(model1) + fitted.model2 <- fitted(model2) + # The predicted clusters are different + expect_equal(sort(collect(distinct(select(fitted.model1, "prediction")))$prediction), + c(0, 1, 2, 3)) + expect_equal(sort(collect(distinct(select(fitted.model2, "prediction")))$prediction), + c(0, 1, 2)) }) test_that("spark.lda with libsvm", { |