From 87ac84d43729c54be100bb9ad7dc6e8fa14b8805 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Fri, 29 Apr 2016 09:42:54 -0700
Subject: [SPARK-14314][SPARK-14315][ML][SPARKR] Model persistence in SparkR (glm & kmeans)

SparkR ```glm``` and ```kmeans``` model persistence.

Unit tests.

Author: Yanbo Liang
Author: Gayathri Murali

Closes #12778 from yanboliang/spark-14311.
Closes #12680
Closes #12683
---
 R/pkg/inst/tests/testthat/test_mllib.R | 41 ++++++++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+)
 (limited to 'R/pkg/inst/tests')

diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 954abb00d4..6a822be121 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -126,6 +126,33 @@ test_that("glm summary", {
   expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
 })
 
+test_that("glm save/load", {
+  training <- suppressWarnings(createDataFrame(sqlContext, iris))
+  m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
+  s <- summary(m)
+
+  modelPath <- tempfile(pattern = "glm", fileext = ".tmp")
+  ml.save(m, modelPath)
+  expect_error(ml.save(m, modelPath))
+  ml.save(m, modelPath, overwrite = TRUE)
+  m2 <- ml.load(modelPath)
+  s2 <- summary(m2)
+
+  expect_equal(s$coefficients, s2$coefficients)
+  expect_equal(rownames(s$coefficients), rownames(s2$coefficients))
+  expect_equal(s$dispersion, s2$dispersion)
+  expect_equal(s$null.deviance, s2$null.deviance)
+  expect_equal(s$deviance, s2$deviance)
+  expect_equal(s$df.null, s2$df.null)
+  expect_equal(s$df.residual, s2$df.residual)
+  expect_equal(s$aic, s2$aic)
+  expect_equal(s$iter, s2$iter)
+  expect_true(!s$is.loaded)
+  expect_true(s2$is.loaded)
+
+  unlink(modelPath, recursive = TRUE)  # saved model may be a directory; plain unlink() skips those
+})
+
 test_that("kmeans", {
   newIris <- iris
   newIris$Species <- NULL
@@ -150,6 +177,20 @@ test_that("kmeans", {
   summary.model <- summary(model)
   cluster <- summary.model$cluster
   expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "kmeans", fileext = ".tmp")
+  ml.save(model, modelPath)
+  expect_error(ml.save(model, modelPath))
+  ml.save(model, modelPath, overwrite = TRUE)
+  model2 <- ml.load(modelPath)
+  summary2 <- summary(model2)
+  expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size)))
+  expect_equal(summary.model$coefficients, summary2$coefficients)
+  expect_true(!summary.model$is.loaded)
+  expect_true(summary2$is.loaded)
+
+  unlink(modelPath, recursive = TRUE)  # saved model may be a directory; plain unlink() skips those
 })
 
 test_that("naiveBayes", {
-- 
cgit v1.2.3