aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
author lewuathe <lewuathe@me.com> 2015-10-19 10:46:10 -0700
committer DB Tsai <dbt@netflix.com> 2015-10-19 10:46:10 -0700
commit 4c33a34ba3167ae67fdb4978ea2166ce65638fb9 (patch)
tree 5a5dbae89a230ad0c82acab25ba98e9121b1af6b /R/pkg
parent dfa41e63b98c28b087c56f94658b5e99e8a7758c (diff)
download spark-4c33a34ba3167ae67fdb4978ea2166ce65638fb9.tar.gz
spark-4c33a34ba3167ae67fdb4978ea2166ce65638fb9.tar.bz2
spark-4c33a34ba3167ae67fdb4978ea2166ce65638fb9.zip
[SPARK-10668] [ML] Use WeightedLeastSquares in LinearRegression with L2 regularization if the number of features is small
Author: lewuathe <lewuathe@me.com> Author: Lewuathe <sasaki@treasure-data.com> Author: Kai Sasaki <sasaki@treasure-data.com> Author: Lewuathe <lewuathe@me.com> Closes #8884 from Lewuathe/SPARK-10668.
Diffstat (limited to 'R/pkg')
-rw-r--r-- R/pkg/R/mllib.R 5
-rw-r--r-- R/pkg/inst/tests/test_mllib.R 2
2 files changed, 4 insertions, 3 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index cd00bbbeec..25615e805e 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -45,11 +45,12 @@ setClass("PipelineModel", representation(model = "jobj"))
#' summary(model)
#'}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
- function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) {
+ function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0,
+ solver = "auto") {
family <- match.arg(family)
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"fitRModelFormula", deparse(formula), data@sdf, family, lambda,
- alpha)
+ alpha, solver)
return(new("PipelineModel", model = model))
})
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 032f8ec68b..3331ce7383 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -59,7 +59,7 @@ test_that("feature interaction vs native glm", {
test_that("summary coefficients match with native glm", {
training <- createDataFrame(sqlContext, iris)
- stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
+ stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs"))
coefs <- as.vector(stats$coefficients)
rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
expect_true(all(abs(rCoefs - coefs) < 1e-6))