aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg')
-rw-r--r--R/pkg/R/mllib.R5
-rw-r--r--R/pkg/inst/tests/test_mllib.R2
2 files changed, 4 insertions, 3 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index cd00bbbeec..25615e805e 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -45,11 +45,12 @@ setClass("PipelineModel", representation(model = "jobj"))
#' summary(model)
#'}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
- function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) {
+ function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0,
+ solver = "auto") {
family <- match.arg(family)
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"fitRModelFormula", deparse(formula), data@sdf, family, lambda,
- alpha)
+ alpha, solver)
return(new("PipelineModel", model = model))
})
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 032f8ec68b..3331ce7383 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -59,7 +59,7 @@ test_that("feature interaction vs native glm", {
test_that("summary coefficients match with native glm", {
training <- createDataFrame(sqlContext, iris)
- stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
+ stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs"))
coefs <- as.vector(stats$coefficients)
rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
expect_true(all(abs(rCoefs - coefs) < 1e-6))