aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2015-11-10 11:34:36 -0800
committer Xiangrui Meng <meng@databricks.com> 2015-11-10 11:34:36 -0800
commit f14e95115c0939a77ebcb00209696a87fd651ff9 (patch)
tree 047faca37642e049f252e870232492253e2e4ba1 /R
parent 87aedc48c01dffbd880e6ca84076ed47c68f88d0 (diff)
download spark-f14e95115c0939a77ebcb00209696a87fd651ff9.tar.gz
spark-f14e95115c0939a77ebcb00209696a87fd651ff9.tar.bz2
spark-f14e95115c0939a77ebcb00209696a87fd651ff9.zip
[ML][R] SparkR::glm summary result to compare with native R
Follow-up to #9561. Now that [SPARK-11587](https://issues.apache.org/jira/browse/SPARK-11587) has been fixed, we should compare the SparkR::glm summary result with native R output rather than with hard-coded values. mengxr
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #9590 from yanboliang/glm-r-test.
Diffstat (limited to 'R')
-rw-r--r-- R/pkg/R/mllib.R | 2
-rw-r--r-- R/pkg/inst/tests/test_mllib.R | 31
2 files changed, 11 insertions, 22 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 7126b7cde4..f23e1c7f1f 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -106,7 +106,7 @@ setMethod("summary", signature(object = "PipelineModel"),
coefficients <- matrix(coefficients, ncol = 4)
colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
rownames(coefficients) <- unlist(features)
- return(list(DevianceResiduals = devianceResiduals, Coefficients = coefficients))
+ return(list(devianceResiduals = devianceResiduals, coefficients = coefficients))
} else {
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Estimate")
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 42287ea19a..d497ad8c9d 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -72,22 +72,17 @@ test_that("feature interaction vs native glm", {
test_that("summary coefficients match with native glm", {
training <- createDataFrame(sqlContext, iris)
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal"))
- coefs <- unlist(stats$Coefficients)
- devianceResiduals <- unlist(stats$DevianceResiduals)
+ coefs <- unlist(stats$coefficients)
+ devianceResiduals <- unlist(stats$devianceResiduals)
- rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
- rStdError <- c(0.23536, 0.04630, 0.07207, 0.09331)
- rTValue <- c(7.123, 7.557, -13.644, -10.798)
- rPValue <- c(0.0, 0.0, 0.0, 0.0)
+ rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))
+ rCoefs <- unlist(rStats$coefficients)
rDevianceResiduals <- c(-0.95096, 0.72918)
- expect_true(all(abs(rCoefs - coefs[1:4]) < 1e-6))
- expect_true(all(abs(rStdError - coefs[5:8]) < 1e-5))
- expect_true(all(abs(rTValue - coefs[9:12]) < 1e-3))
- expect_true(all(abs(rPValue - coefs[13:16]) < 1e-6))
+ expect_true(all(abs(rCoefs - coefs) < 1e-5))
expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5))
expect_true(all(
- rownames(stats$Coefficients) ==
+ rownames(stats$coefficients) ==
c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
})
@@ -96,21 +91,15 @@ test_that("summary coefficients match with native glm of family 'binomial'", {
training <- filter(df, df$Species != "setosa")
stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
family = "binomial"))
- coefs <- as.vector(stats$Coefficients)
+ coefs <- as.vector(stats$coefficients[,1])
rTraining <- iris[iris$Species %in% c("versicolor","virginica"),]
rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
family = binomial(link = "logit"))))
- rStdError <- c(3.0974, 0.5169, 0.8628)
- rTValue <- c(-4.212, 3.680, 0.469)
- rPValue <- c(0.000, 0.000, 0.639)
-
- expect_true(all(abs(rCoefs - coefs[1:3]) < 1e-4))
- expect_true(all(abs(rStdError - coefs[4:6]) < 1e-4))
- expect_true(all(abs(rTValue - coefs[7:9]) < 1e-3))
- expect_true(all(abs(rPValue - coefs[10:12]) < 1e-3))
+
+ expect_true(all(abs(rCoefs - coefs) < 1e-4))
expect_true(all(
- rownames(stats$Coefficients) ==
+ rownames(stats$coefficients) ==
c("(Intercept)", "Sepal_Length", "Sepal_Width")))
})