aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorXin Ren <iamshrek@126.com>2016-08-31 21:39:31 -0700
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2016-08-31 21:39:31 -0700
commit7a5000f39ef4f195696836f8a4e8ab4ff5c14dd2 (patch)
tree402f7fc410378418369b4abb4e1b3d92a1358a5c /R
parentd008638fbedc857c1adc1dff399d427b8bae848e (diff)
downloadspark-7a5000f39ef4f195696836f8a4e8ab4ff5c14dd2.tar.gz
spark-7a5000f39ef4f195696836f8a4e8ab4ff5c14dd2.tar.bz2
spark-7a5000f39ef4f195696836f8a4e8ab4ff5c14dd2.zip
[SPARK-17241][SPARKR][MLLIB] SparkR spark.glm should have configurable regularization parameter
https://issues.apache.org/jira/browse/SPARK-17241 ## What changes were proposed in this pull request? Spark has configurable L2 regularization parameter for generalized linear regression. It is very important to have them in SparkR so that users can run ridge regression. ## How was this patch tested? Test manually on local laptop. Author: Xin Ren <iamshrek@126.com> Closes #14856 from keypointt/SPARK-17241.
Diffstat (limited to 'R')
-rw-r--r--R/pkg/R/mllib.R10
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R6
2 files changed, 12 insertions, 4 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 64d19fab7e..9a53f757b4 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -138,10 +138,11 @@ predict_internal <- function(object, newData) {
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
-#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
-#' weights as 1.0.
#' @param tol positive convergence tolerance of iterations.
#' @param maxIter integer giving the maximal number of IRLS iterations.
+#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
+#' weights as 1.0.
+#' @param regParam regularization parameter for L2 regularization.
#' @param ... additional arguments passed to the method.
#' @aliases spark.glm,SparkDataFrame,formula-method
#' @return \code{spark.glm} returns a fitted generalized linear model
@@ -171,7 +172,8 @@ predict_internal <- function(object, newData) {
#' @note spark.glm since 2.0.0
#' @seealso \link{glm}, \link{read.ml}
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
- function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL) {
+ function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL,
+ regParam = 0.0) {
if (is.character(family)) {
family <- get(family, mode = "function", envir = parent.frame())
}
@@ -190,7 +192,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, family$family, family$link,
- tol, as.integer(maxIter), as.character(weightCol))
+ tol, as.integer(maxIter), as.character(weightCol), regParam)
new("GeneralizedLinearRegressionModel", jobj = jobj)
})
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 1e6da650d1..825a24073b 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -148,6 +148,12 @@ test_that("spark.glm summary", {
baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
baseSummary <- summary(baseModel)
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
+
+ # Test spark.glm works with regularization parameter
+ data <- as.data.frame(cbind(a1, a2, b))
+ df <- suppressWarnings(createDataFrame(data))
+ regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0))
+ expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result
})
test_that("spark.glm save/load", {