aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--R/pkg/R/mllib.R10
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala40
4 files changed, 55 insertions, 5 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 64d19fab7e..9a53f757b4 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -138,10 +138,11 @@ predict_internal <- function(object, newData) {
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
-#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
-#' weights as 1.0.
#' @param tol positive convergence tolerance of iterations.
#' @param maxIter integer giving the maximal number of IRLS iterations.
+#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
+#' weights as 1.0.
+#' @param regParam regularization parameter for L2 regularization.
#' @param ... additional arguments passed to the method.
#' @aliases spark.glm,SparkDataFrame,formula-method
#' @return \code{spark.glm} returns a fitted generalized linear model
@@ -171,7 +172,8 @@ predict_internal <- function(object, newData) {
#' @note spark.glm since 2.0.0
#' @seealso \link{glm}, \link{read.ml}
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
- function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL) {
+ function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL,
+ regParam = 0.0) {
if (is.character(family)) {
family <- get(family, mode = "function", envir = parent.frame())
}
@@ -190,7 +192,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, family$family, family$link,
- tol, as.integer(maxIter), as.character(weightCol))
+ tol, as.integer(maxIter), as.character(weightCol), regParam)
new("GeneralizedLinearRegressionModel", jobj = jobj)
})
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 1e6da650d1..825a24073b 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -148,6 +148,12 @@ test_that("spark.glm summary", {
baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
baseSummary <- summary(baseModel)
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
+
+ # Test spark.glm works with regularization parameter
+ data <- as.data.frame(cbind(a1, a2, b))
+ df <- suppressWarnings(createDataFrame(data))
+ regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0))
+ expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result
})
test_that("spark.glm save/load", {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 0d3181d0ac..7a6ab618a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -69,7 +69,8 @@ private[r] object GeneralizedLinearRegressionWrapper
link: String,
tol: Double,
maxIter: Int,
- weightCol: String): GeneralizedLinearRegressionWrapper = {
+ weightCol: String,
+ regParam: Double): GeneralizedLinearRegressionWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
val rFormulaModel = rFormula.fit(data)
@@ -86,6 +87,7 @@ private[r] object GeneralizedLinearRegressionWrapper
.setTol(tol)
.setMaxIter(maxIter)
.setWeightCol(weightCol)
+ .setRegParam(regParam)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glr))
.fit(data)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index a4568e83fa..d8032c4e17 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1034,6 +1034,46 @@ class GeneralizedLinearRegressionSuite
.setFamily("gaussian")
.fit(datasetGaussianIdentity.as[LabeledPoint])
}
+
+ test("generalized linear regression: regularization parameter") {
+ /*
+ R code:
+
+ a1 <- c(0, 1, 2, 3)
+ a2 <- c(5, 2, 1, 3)
+ b <- c(1, 0, 1, 0)
+ data <- as.data.frame(cbind(a1, a2, b))
+ df <- suppressWarnings(createDataFrame(data))
+
+ for (regParam in c(0.0, 0.1, 1.0)) {
+ model <- spark.glm(df, b ~ a1 + a2, regParam = regParam)
+ print(as.vector(summary(model)$aic))
+ }
+
+ [1] 12.88188
+ [1] 12.92681
+ [1] 13.32836
+ */
+ val dataset = spark.createDataFrame(Seq(
+ LabeledPoint(1, Vectors.dense(5, 0)),
+ LabeledPoint(0, Vectors.dense(2, 1)),
+ LabeledPoint(1, Vectors.dense(1, 2)),
+ LabeledPoint(0, Vectors.dense(3, 3))
+ ))
+ val expected = Seq(12.88188, 12.92681, 13.32836)
+
+ var idx = 0
+ for (regParam <- Seq(0.0, 0.1, 1.0)) {
+ val trainer = new GeneralizedLinearRegression()
+ .setRegParam(regParam)
+ .setLabelCol("label")
+ .setFeaturesCol("features")
+ val model = trainer.fit(dataset)
+ val actual = model.summary.aic
+ assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with regParam = $regParam.")
+ idx += 1
+ }
+ }
}
object GeneralizedLinearRegressionSuite {