diff options
author | lewuathe <lewuathe@me.com> | 2015-10-19 10:46:10 -0700 |
---|---|---|
committer | DB Tsai <dbt@netflix.com> | 2015-10-19 10:46:10 -0700 |
commit | 4c33a34ba3167ae67fdb4978ea2166ce65638fb9 (patch) | |
tree | 5a5dbae89a230ad0c82acab25ba98e9121b1af6b /R/pkg | |
parent | dfa41e63b98c28b087c56f94658b5e99e8a7758c (diff) | |
download | spark-4c33a34ba3167ae67fdb4978ea2166ce65638fb9.tar.gz spark-4c33a34ba3167ae67fdb4978ea2166ce65638fb9.tar.bz2 spark-4c33a34ba3167ae67fdb4978ea2166ce65638fb9.zip |
[SPARK-10668] [ML] Use WeightedLeastSquares in LinearRegression with L2 regularization if the number of features is small
Author: lewuathe <lewuathe@me.com>
Author: Lewuathe <sasaki@treasure-data.com>
Author: Kai Sasaki <sasaki@treasure-data.com>
Author: Lewuathe <lewuathe@me.com>
Closes #8884 from Lewuathe/SPARK-10668.
Diffstat (limited to 'R/pkg')
-rw-r--r-- | R/pkg/R/mllib.R | 5 | ||||
-rw-r--r-- | R/pkg/inst/tests/test_mllib.R | 2 |
2 files changed, 4 insertions, 3 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index cd00bbbeec..25615e805e 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -45,11 +45,12 @@ setClass("PipelineModel", representation(model = "jobj")) #' summary(model) #'} setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), - function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) { + function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, + solver = "auto") { family <- match.arg(family) model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitRModelFormula", deparse(formula), data@sdf, family, lambda, - alpha) + alpha, solver) return(new("PipelineModel", model = model)) }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 032f8ec68b..3331ce7383 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -59,7 +59,7 @@ test_that("feature interaction vs native glm", { test_that("summary coefficients match with native glm", { training <- createDataFrame(sqlContext, iris) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs")) coefs <- as.vector(stats$coefficients) rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) expect_true(all(abs(rCoefs - coefs) < 1e-6)) |