diff options
Diffstat (limited to 'R/pkg')
-rw-r--r-- | R/pkg/R/mllib.R | 5 | ||||
-rw-r--r-- | R/pkg/inst/tests/test_mllib.R | 2 |
2 files changed, 4 insertions, 3 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index cd00bbbeec..25615e805e 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -45,11 +45,12 @@ setClass("PipelineModel", representation(model = "jobj")) #' summary(model) #'} setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), - function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) { + function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, + solver = "auto") { family <- match.arg(family) model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitRModelFormula", deparse(formula), data@sdf, family, lambda, - alpha) + alpha, solver) return(new("PipelineModel", model = model)) }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 032f8ec68b..3331ce7383 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -59,7 +59,7 @@ test_that("feature interaction vs native glm", { test_that("summary coefficients match with native glm", { training <- createDataFrame(sqlContext, iris) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs")) coefs <- as.vector(stats$coefficients) rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) expect_true(all(abs(rCoefs - coefs) < 1e-6)) |