aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests/test_mllib.R
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-11-04 08:28:33 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-04 08:28:33 -0800
commite328b69c31821e4b27673d7ef6182ab3b7a05ca8 (patch)
tree7bd0416235fa72fca7097accb6c0a7a6019f80e5 /R/pkg/inst/tests/test_mllib.R
parentc09e5139874fb3626e005c8240cca5308b902ef3 (diff)
downloadspark-e328b69c31821e4b27673d7ef6182ab3b7a05ca8.tar.gz
spark-e328b69c31821e4b27673d7ef6182ab3b7a05ca8.tar.bz2
spark-e328b69c31821e4b27673d7ef6182ab3b7a05ca8.zip
[SPARK-9492][ML][R] LogisticRegression in R should provide model statistics
Like ml ```LinearRegression```, ```LogisticRegression``` should provide a training summary including feature names and their coefficients. Author: Yanbo Liang <ybliang8@gmail.com> Closes #9303 from yanboliang/spark-9492.
Diffstat (limited to 'R/pkg/inst/tests/test_mllib.R')
-rw-r--r--R/pkg/inst/tests/test_mllib.R17
1 files changed, 17 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 3331ce7383..032cfef061 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -67,3 +67,20 @@ test_that("summary coefficients match with native glm", {
as.character(stats$features) ==
c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
})
+
+test_that("summary coefficients match with native glm of family 'binomial'", {
+ df <- createDataFrame(sqlContext, iris)
+ training <- filter(df, df$Species != "setosa")
+ stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
+ family = "binomial"))
+ coefs <- as.vector(stats$coefficients)
+
+ rTraining <- iris[iris$Species %in% c("versicolor","virginica"),]
+ rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
+ family = binomial(link = "logit"))))
+
+ expect_true(all(abs(rCoefs - coefs) < 1e-4))
+ expect_true(all(
+ as.character(stats$features) ==
+ c("(Intercept)", "Sepal_Length", "Sepal_Width")))
+})