aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/inst/tests')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib_classification.R17
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib_clustering.R4
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib_regression.R27
3 files changed, 30 insertions, 18 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R
index 2e0dea321e..5f84a620c1 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_classification.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R
@@ -68,12 +68,17 @@ test_that("spark.logit", {
df <- suppressWarnings(createDataFrame(iris))
model <- spark.logit(df, Species ~ ., regParam = 0.5)
summary <- summary(model)
+
+ # test summary coefficients return matrix type
+ expect_true(class(summary$coefficients) == "matrix")
+ expect_true(class(summary$coefficients[, 1]) == "numeric")
+
versicolorCoefsR <- c(1.52, 0.03, -0.53, 0.04, 0.00)
virginicaCoefsR <- c(-2.62, 0.27, -0.02, 0.16, 0.42)
setosaCoefsR <- c(1.10, -0.29, 0.55, -0.19, -0.42)
- versicolorCoefs <- unlist(summary$coefficients[, "versicolor"])
- virginicaCoefs <- unlist(summary$coefficients[, "virginica"])
- setosaCoefs <- unlist(summary$coefficients[, "setosa"])
+ versicolorCoefs <- summary$coefficients[, "versicolor"]
+ virginicaCoefs <- summary$coefficients[, "virginica"]
+ setosaCoefs <- summary$coefficients[, "setosa"]
expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1))
expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1))
expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1))
@@ -136,8 +141,8 @@ test_that("spark.logit", {
summary <- summary(model)
versicolorCoefsR <- c(3.94, -0.16, -0.02, -0.35, -0.78)
virginicaCoefsR <- c(-3.94, 0.16, -0.02, 0.35, 0.78)
- versicolorCoefs <- unlist(summary$coefficients[, "versicolor"])
- virginicaCoefs <- unlist(summary$coefficients[, "virginica"])
+ versicolorCoefs <- summary$coefficients[, "versicolor"]
+ virginicaCoefs <- summary$coefficients[, "virginica"]
expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1))
expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1))
@@ -145,7 +150,7 @@ test_that("spark.logit", {
model <- spark.logit(training, Species ~ ., regParam = 0.5)
summary <- summary(model)
coefsR <- c(-6.08, 0.25, 0.16, 0.48, 1.04)
- coefs <- unlist(summary$coefficients[, "Estimate"])
+ coefs <- summary$coefficients[, "Estimate"]
expect_true(all(abs(coefsR - coefs) < 0.1))
# Test prediction with string label
diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
index aad834bb64..28a6eeba2c 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
@@ -166,6 +166,10 @@ test_that("spark.kmeans", {
expect_equal(k, 2)
expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+ # test summary coefficients return matrix type
+ expect_true(class(summary.model$coefficients) == "matrix")
+ expect_true(class(summary.model$coefficients[1, ]) == "numeric")
+
# Test model save/load
modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp")
write.ml(model, modelPath)
diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R
index c450a15171..81a5bdc414 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_regression.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R
@@ -87,11 +87,14 @@ test_that("spark.glm summary", {
# gaussian family
training <- suppressWarnings(createDataFrame(iris))
stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species))
-
rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))
- coefs <- unlist(stats$coefficients)
- rCoefs <- unlist(rStats$coefficients)
+ # test summary coefficients return matrix type
+ expect_true(class(stats$coefficients) == "matrix")
+ expect_true(class(stats$coefficients[, 1]) == "numeric")
+
+ coefs <- stats$coefficients
+ rCoefs <- rStats$coefficients
expect_true(all(abs(rCoefs - coefs) < 1e-4))
expect_true(all(
rownames(stats$coefficients) ==
@@ -117,8 +120,8 @@ test_that("spark.glm summary", {
rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
family = binomial(link = "logit")))
- coefs <- unlist(stats$coefficients)
- rCoefs <- unlist(rStats$coefficients)
+ coefs <- stats$coefficients
+ rCoefs <- rStats$coefficients
expect_true(all(abs(rCoefs - coefs) < 1e-4))
expect_true(all(
rownames(stats$coefficients) ==
@@ -141,8 +144,8 @@ test_that("spark.glm summary", {
stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w"))
rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w))
- coefs <- unlist(stats$coefficients)
- rCoefs <- unlist(rStats$coefficients)
+ coefs <- stats$coefficients
+ rCoefs <- rStats$coefficients
expect_true(all(abs(rCoefs - coefs) < 1e-3))
expect_true(all(rownames(stats$coefficients) == c("(Intercept)", "a1", "a2")))
expect_equal(stats$dispersion, rStats$dispersion)
@@ -169,7 +172,7 @@ test_that("spark.glm summary", {
data <- as.data.frame(cbind(A, b))
df <- createDataFrame(data)
stats <- summary(spark.glm(df, b ~ . - 1))
- coefs <- unlist(stats$coefficients)
+ coefs <- stats$coefficients
expect_true(all(abs(c(0.5, 0.25) - coefs) < 1e-4))
})
@@ -259,8 +262,8 @@ test_that("glm summary", {
rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))
- coefs <- unlist(stats$coefficients)
- rCoefs <- unlist(rStats$coefficients)
+ coefs <- stats$coefficients
+ rCoefs <- rStats$coefficients
expect_true(all(abs(rCoefs - coefs) < 1e-4))
expect_true(all(
rownames(stats$coefficients) ==
@@ -282,8 +285,8 @@ test_that("glm summary", {
rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
family = binomial(link = "logit")))
- coefs <- unlist(stats$coefficients)
- rCoefs <- unlist(rStats$coefficients)
+ coefs <- stats$coefficients
+ rCoefs <- rStats$coefficients
expect_true(all(abs(rCoefs - coefs) < 1e-4))
expect_true(all(
rownames(stats$coefficients) ==