aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-04-29 09:42:54 -0700
committerXiangrui Meng <meng@databricks.com>2016-04-29 09:43:04 -0700
commit87ac84d43729c54be100bb9ad7dc6e8fa14b8805 (patch)
treed3fbb8c5996a10177fd3af3579d160b6278509ac /R/pkg/inst/tests
parenta7d0fedc940721d09350f2e57ae85591e0a3d90e (diff)
downloadspark-87ac84d43729c54be100bb9ad7dc6e8fa14b8805.tar.gz
spark-87ac84d43729c54be100bb9ad7dc6e8fa14b8805.tar.bz2
spark-87ac84d43729c54be100bb9ad7dc6e8fa14b8805.zip
[SPARK-14314][SPARK-14315][ML][SPARKR] Model persistence in SparkR (glm & kmeans)
SparkR ```glm``` and ```kmeans``` model persistence. Unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Author: Gayathri Murali <gayathri.m.softie@gmail.com> Closes #12778 from yanboliang/spark-14311. Closes #12680 Closes #12683
Diffstat (limited to 'R/pkg/inst/tests')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R41
1 files changed, 41 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 954abb00d4..6a822be121 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -126,6 +126,33 @@ test_that("glm summary", {
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
})
+test_that("glm save/load", {
+ training <- suppressWarnings(createDataFrame(sqlContext, iris))
+ m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
+ s <- summary(m)
+
+ modelPath <- tempfile(pattern = "glm", fileext = ".tmp")
+ ml.save(m, modelPath)
+ expect_error(ml.save(m, modelPath))
+ ml.save(m, modelPath, overwrite = TRUE)
+ m2 <- ml.load(modelPath)
+ s2 <- summary(m2)
+
+ expect_equal(s$coefficients, s2$coefficients)
+ expect_equal(rownames(s$coefficients), rownames(s2$coefficients))
+ expect_equal(s$dispersion, s2$dispersion)
+ expect_equal(s$null.deviance, s2$null.deviance)
+ expect_equal(s$deviance, s2$deviance)
+ expect_equal(s$df.null, s2$df.null)
+ expect_equal(s$df.residual, s2$df.residual)
+ expect_equal(s$aic, s2$aic)
+ expect_equal(s$iter, s2$iter)
+ expect_true(!s$is.loaded)
+ expect_true(s2$is.loaded)
+
+ unlink(modelPath)
+})
+
test_that("kmeans", {
newIris <- iris
newIris$Species <- NULL
@@ -150,6 +177,20 @@ test_that("kmeans", {
summary.model <- summary(model)
cluster <- summary.model$cluster
expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "kmeans", fileext = ".tmp")
+ ml.save(model, modelPath)
+ expect_error(ml.save(model, modelPath))
+ ml.save(model, modelPath, overwrite = TRUE)
+ model2 <- ml.load(modelPath)
+ summary2 <- summary(model2)
+ expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size)))
+ expect_equal(summary.model$coefficients, summary2$coefficients)
+ expect_true(!summary.model$is.loaded)
+ expect_true(summary2$is.loaded)
+
+ unlink(modelPath)
})
test_that("naiveBayes", {