aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-02-23 15:42:58 -0800
committerXiangrui Meng <meng@databricks.com>2016-02-23 15:42:58 -0800
commit8d29001dec5c3695721a76df3f70da50512ef28f (patch)
treedcb610ddff00188cf9898cce6d3eee029c44010b /R/pkg/inst/tests
parent15e30155631d52e35ab8522584027ab350e5acb3 (diff)
downloadspark-8d29001dec5c3695721a76df3f70da50512ef28f.tar.gz
spark-8d29001dec5c3695721a76df3f70da50512ef28f.tar.bz2
spark-8d29001dec5c3695721a76df3f70da50512ef28f.zip
[SPARK-13011] K-means wrapper in SparkR
https://issues.apache.org/jira/browse/SPARK-13011 Author: Xusen Yin <yinxusen@gmail.com> Closes #11124 from yinxusen/SPARK-13011.
Diffstat (limited to 'R/pkg/inst/tests')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R28
1 files changed, 28 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 08099dd96a..595512e0e0 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -113,3 +113,31 @@ test_that("summary works on base GLM models", {
baseSummary <- summary(baseModel)
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
})
+
+test_that("kmeans", {
+ newIris <- iris
+ newIris$Species <- NULL
+ training <- suppressWarnings(createDataFrame(sqlContext, newIris))
+
+ # Cache the DataFrame here to work around the bug SPARK-13178.
+ cache(training)
+ take(training, 1)
+
+ model <- kmeans(x = training, centers = 2)
+ sample <- take(select(predict(model, training), "prediction"), 1)
+ expect_equal(typeof(sample$prediction), "integer")
+ expect_equal(sample$prediction, 1)
+
+ # Test stats::kmeans is working
+ statsModel <- kmeans(x = newIris, centers = 2)
+ expect_equal(unique(statsModel$cluster), c(1, 2))
+
+ # Test fitted works on KMeans
+ fitted.model <- fitted(model)
+ expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1))
+
+ # Test summary works on KMeans
+ summary.model <- summary(model)
+ cluster <- summary.model$cluster
+ expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+})