aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests
diff options
context:
space:
mode:
authorwm624@hotmail.com <wm624@hotmail.com>2017-01-31 21:16:37 -0800
committerFelix Cheung <felixcheung@apache.org>2017-01-31 21:16:37 -0800
commit9ac05225e870e41dc86cd6d61c7f0d111d172810 (patch)
tree0576e4c396b0fd34abb94c63bdcae2beb623d64a /R/pkg/inst/tests
parent9063835803e54538c94d95bbddcb4810dd7a1c55 (diff)
downloadspark-9ac05225e870e41dc86cd6d61c7f0d111d172810.tar.gz
spark-9ac05225e870e41dc86cd6d61c7f0d111d172810.tar.bz2
spark-9ac05225e870e41dc86cd6d61c7f0d111d172810.zip
[SPARK-19319][SPARKR] SparkR Kmeans summary returns error when the cluster size doesn't equal to k
## What changes were proposed in this pull request When Kmeans using initMode = "random" and some random seed, it is possible the actual cluster size doesn't equal to the configured `k`. In this case, summary(model) returns error due to the number of cols of coefficient matrix doesn't equal to k. Example: > col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) > col2 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) > col3 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) > cols <- as.data.frame(cbind(col1, col2, col3)) > df <- createDataFrame(cols) > > model2 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, initMode = "random", seed = 22222, tol = 1E-5) > > summary(model2) Error in `colnames<-`(`*tmp*`, value = c("col1", "col2", "col3")) : length of 'dimnames' [2] not equal to array extent In addition: Warning message: In matrix(coefficients, ncol = k) : data length [9] is not a sub-multiple or multiple of the number of rows [2] Fix: Get the actual cluster size in the summary and use it to build the coefficient matrix. ## How was this patch tested? Add unit tests. Author: wm624@hotmail.com <wm624@hotmail.com> Closes #16666 from wangmiao1981/kmeans.
Diffstat (limited to 'R/pkg/inst/tests')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib_clustering.R15
1 files changed, 11 insertions, 4 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
index 28a6eeba2c..1661e987b7 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
@@ -196,13 +196,20 @@ test_that("spark.kmeans", {
model2 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10,
initMode = "random", seed = 22222, tol = 1E-5)
- fitted.model1 <- fitted(model1)
- fitted.model2 <- fitted(model2)
+ summary.model1 <- summary(model1)
+ summary.model2 <- summary(model2)
+ cluster1 <- summary.model1$cluster
+ cluster2 <- summary.model2$cluster
+ clusterSize1 <- summary.model1$clusterSize
+ clusterSize2 <- summary.model2$clusterSize
+
# The predicted clusters are different
- expect_equal(sort(collect(distinct(select(fitted.model1, "prediction")))$prediction),
+ expect_equal(sort(collect(distinct(select(cluster1, "prediction")))$prediction),
c(0, 1, 2, 3))
- expect_equal(sort(collect(distinct(select(fitted.model2, "prediction")))$prediction),
+ expect_equal(sort(collect(distinct(select(cluster2, "prediction")))$prediction),
c(0, 1, 2))
+ expect_equal(clusterSize1, 4)
+ expect_equal(clusterSize2, 3)
})
test_that("spark.lda with libsvm", {