aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst
diff options
context:
space:
mode:
authorFelix Cheung <felixcheung_m@hotmail.com>2017-01-10 11:42:07 -0800
committerFelix Cheung <felixcheung@apache.org>2017-01-10 11:42:07 -0800
commit9bc3507e411b0ad9207e3053f80ac82f19b18f26 (patch)
tree05b787ff541d809eef4913b0c61824556ff3fd9b /R/pkg/inst
parentd5b1dc934a2482886c2c095de90e4c6a49ec42bd (diff)
downloadspark-9bc3507e411b0ad9207e3053f80ac82f19b18f26.tar.gz
spark-9bc3507e411b0ad9207e3053f80ac82f19b18f26.tar.bz2
spark-9bc3507e411b0ad9207e3053f80ac82f19b18f26.zip
[SPARK-19133][SPARKR][ML] fix glm for Gamma, clarify glm family supported
## What changes were proposed in this pull request? R family is a longer list than what Spark supports. ## How was this patch tested? manual Author: Felix Cheung <felixcheung_m@hotmail.com> Closes #16511 from felixcheung/rdocglmfamily.
Diffstat (limited to 'R/pkg/inst')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib_regression.R26
1 files changed, 17 insertions, 9 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R
index e20dafa414..c450a15171 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_regression.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R
@@ -61,14 +61,22 @@ test_that("spark.glm and predict", {
# poisson family
model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species,
- family = poisson(link = identity))
+ family = poisson(link = identity))
prediction <- predict(model, training)
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
vals <- collect(select(prediction, "prediction"))
rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species,
- data = iris, family = poisson(link = identity)), iris))
+ data = iris, family = poisson(link = identity)), iris))
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+ # Gamma family
+ x <- runif(100, -1, 1)
+ y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10)
+ df <- as.DataFrame(as.data.frame(list(x = x, y = y)))
+ model <- glm(y ~ x, family = Gamma, df)
+ out <- capture.output(print(summary(model)))
+ expect_true(any(grepl("Dispersion parameter for gamma family", out)))
+
# Test stats::predict is working
x <- rnorm(15)
y <- x + rnorm(15)
@@ -103,11 +111,11 @@ test_that("spark.glm summary", {
df <- suppressWarnings(createDataFrame(iris))
training <- df[df$Species %in% c("versicolor", "virginica"), ]
stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width,
- family = binomial(link = "logit")))
+ family = binomial(link = "logit")))
rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
- family = binomial(link = "logit")))
+ family = binomial(link = "logit")))
coefs <- unlist(stats$coefficients)
rCoefs <- unlist(rStats$coefficients)
@@ -222,7 +230,7 @@ test_that("glm and predict", {
training <- suppressWarnings(createDataFrame(iris))
# gaussian family
model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
- prediction <- predict(model, training)
+ prediction <- predict(model, training)
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
vals <- collect(select(prediction, "prediction"))
rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
@@ -235,7 +243,7 @@ test_that("glm and predict", {
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
vals <- collect(select(prediction, "prediction"))
rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species,
- data = iris, family = poisson(link = identity)), iris))
+ data = iris, family = poisson(link = identity)), iris))
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
# Test stats::predict is working
@@ -268,11 +276,11 @@ test_that("glm summary", {
df <- suppressWarnings(createDataFrame(iris))
training <- df[df$Species %in% c("versicolor", "virginica"), ]
stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
- family = binomial(link = "logit")))
+ family = binomial(link = "logit")))
rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
- family = binomial(link = "logit")))
+ family = binomial(link = "logit")))
coefs <- unlist(stats$coefficients)
rCoefs <- unlist(rStats$coefficients)
@@ -409,7 +417,7 @@ test_that("spark.survreg", {
x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
expect_error(
model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData),
- NA)
+ NA)
expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4)
}
})