about summary refs log tree commit diff
path: root/R/pkg/inst/tests/testthat/test_mllib.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/inst/tests/testthat/test_mllib.R')
-rw-r--r-- R/pkg/inst/tests/testthat/test_mllib.R 62
1 files changed, 62 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index b759b28927..96179864a8 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -508,4 +508,66 @@ test_that("spark.isotonicRegression", {
unlink(modelPath)
})
+test_that("spark.gaussianMixture", {
+  # R code to reproduce the result (rmvnorm: mvtnorm; mvnormalmixEM: mixtools).
+  # nolint start
+  #' library(mvtnorm)
+  #' set.seed(100)
+  #' a <- rmvnorm(4, c(0, 0))
+  #' b <- rmvnorm(6, c(3, 4))
+  #' data <- rbind(a, b)
+  #' model <- mixtools::mvnormalmixEM(data, k = 2)
+  #' model$lambda
+  #
+  #  [1] 0.4 0.6
+  #
+  #' model$mu
+  #
+  #  [1] -0.2614822  0.5128697
+  #  [1] 2.647284 4.544682
+  #
+  #' model$sigma
+  #
+  #  [[1]]
+  #             [,1]       [,2]
+  #  [1,] 0.08427399 0.00548772
+  #  [2,] 0.00548772 0.09090715
+  #
+  #  [[2]]
+  #             [,1]       [,2]
+  #  [1,]  0.1641373 -0.1673806
+  #  [2,] -0.1673806  0.7508951
+  # nolint end
+  data <- list(list(-0.50219235, 0.1315312), list(-0.07891709, 0.8867848),
+               list(0.11697127, 0.3186301), list(-0.58179068, 0.7145327),
+               list(2.17474057, 3.6401379), list(3.08988614, 4.0962745),
+               list(2.79836605, 4.7398405), list(3.12337950, 3.9706833),
+               list(2.61114575, 4.5108563), list(2.08618581, 6.3102968))
+  df <- createDataFrame(data, c("x1", "x2"))
+  model <- spark.gaussianMixture(df, ~ x1 + x2, k = 2)
+  stats <- summary(model)
+  rLambda <- c(0.4, 0.6)
+  rMu <- c(-0.2614822, 0.5128697, 2.647284, 4.544682)
+  rSigma <- c(0.08427399, 0.00548772, 0.00548772, 0.09090715,
+              0.1641373, -0.1673806, -0.1673806, 0.7508951)
+  expect_equal(stats$lambda, rLambda, tolerance = 1e-3)  # fitted, not exact
+  expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3)
+  expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3)
+  p <- collect(select(predict(model, df), "prediction"))
+  expect_equal(p$prediction, c(0, 0, 0, 0, 1, 1, 1, 1, 1, 1))
+
+  # Test model save/load: a second write must fail unless overwrite = TRUE.
+  modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  expect_equal(stats$lambda, stats2$lambda)
+  expect_equal(unlist(stats$mu), unlist(stats2$mu))
+  expect_equal(unlist(stats$sigma), unlist(stats2$sigma))
+
+  unlink(modelPath, recursive = TRUE)  # saved model is a directory
+})
+
sparkR.session.stop()